Mirror of https://github.com/ROCm/jax.git
Synced 2025-04-16 03:46:06 +00:00

Change onp/np to np/jnp in docs & notebooks (#3760)

This commit is contained in:
  parent 150d028d9d
  commit 05904faf0f

README.md (40)
@@ -52,18 +52,18 @@ bugs](https://github.com/google/jax/issues), and letting us know what you
 think!

 ```python
-import jax.numpy as np
+import jax.numpy as jnp
 from jax import grad, jit, vmap

 def predict(params, inputs):
   for W, b in params:
-    outputs = np.dot(inputs, W) + b
-    inputs = np.tanh(outputs)
+    outputs = jnp.dot(inputs, W) + b
+    inputs = jnp.tanh(outputs)
   return outputs

 def logprob_fun(params, inputs, targets):
   preds = predict(params, inputs)
-  return np.sum((preds - targets)**2)
+  return jnp.sum((preds - targets)**2)

 grad_fun = jit(grad(logprob_fun))  # compiled gradient evaluation function
 perex_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0)))  # fast per-example grads
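The last two lines of that hunk are the punchline: `grad_fun` computes one gradient for the whole batch, while `perex_grads` vmaps that same gradient over the batch's leading axis. A minimal usage sketch, not part of the diff; the layer sizes and batch size below are hypothetical and chosen only so the shapes line up:

```python
from jax import random

key = random.PRNGKey(0)
k1, k2, k3 = random.split(key, 3)

# One hidden layer, 3 -> 4 -> 1, as a list of (W, b) pairs.
params = [(random.normal(k1, (3, 4)), jnp.zeros(4)),
          (random.normal(k2, (4, 1)), jnp.zeros(1))]

inputs  = random.normal(k3, (8, 3))   # batch of 8 examples
targets = jnp.zeros((8, 1))

batch_grad  = grad_fun(params, inputs, targets)     # one gradient of the summed loss
per_example = perex_grads(params, inputs, targets)  # every leaf gains a leading axis of 8
```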
@@ -114,10 +114,10 @@ for reverse-mode gradients:

 ```python
 from jax import grad
-import jax.numpy as np
+import jax.numpy as jnp

 def tanh(x):  # Define a function
-  y = np.exp(-2.0 * x)
+  y = jnp.exp(-2.0 * x)
   return (1.0 - y) / (1.0 + y)

 grad_tanh = grad(tanh)  # Obtain its gradient function
@@ -176,14 +176,14 @@ You can use XLA to compile your functions end-to-end with
 used either as an `@jit` decorator or as a higher-order function.

 ```python
-import jax.numpy as np
+import jax.numpy as jnp
 from jax import jit

 def slow_f(x):
   # Element-wise ops see a large benefit from fusion
   return x * x + x * 2.0

-x = np.ones((5000, 5000))
+x = jnp.ones((5000, 5000))
 fast_f = jit(slow_f)
 %timeit -n10 -r3 fast_f(x)  # ~ 4.5 ms / loop on Titan X
 %timeit -n10 -r3 slow_f(x)  # ~ 14.5 ms / loop (also on GPU via JAX)
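Because JAX dispatches work asynchronously (see the asynchronous-dispatch notes further down in this diff), timings like the ones above are most reliable when you block on the result and exclude the first, compile-heavy call. A minimal sketch, reusing `fast_f` and `x` from the snippet above:

```python
fast_f(x).block_until_ready()                   # first call compiles; warm up and wait
%timeit -n10 -r3 fast_f(x).block_until_ready()  # then time steady-state execution only
```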
@@ -213,19 +213,19 @@ function:
 def predict(params, input_vec):
   assert input_vec.ndim == 1
   for W, b in params:
-    output_vec = np.dot(W, input_vec) + b  # `input_vec` on the right-hand side!
-    input_vec = np.tanh(output_vec)
+    output_vec = jnp.dot(W, input_vec) + b  # `input_vec` on the right-hand side!
+    input_vec = jnp.tanh(output_vec)
   return output_vec
 ```

-We often instead write `np.dot(inputs, W)` to allow for a batch dimension on the
+We often instead write `jnp.dot(inputs, W)` to allow for a batch dimension on the
 left side of `inputs`, but we’ve written this particular prediction function to
 apply only to single input vectors. If we wanted to apply this function to a
 batch of inputs at once, semantically we could just write

 ```python
 from functools import partial
-predictions = np.stack(list(map(partial(predict, params), input_batch)))
+predictions = jnp.stack(list(map(partial(predict, params), input_batch)))
 ```

 But pushing one example through the network at a time would be slow! It’s better
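That is exactly what `vmap` is for: it rewrites `predict` to act on a whole batch without the Python-level loop. A minimal sketch; `params` and `input_batch` are the same hypothetical arrays as in the snippet above:

```python
from jax import vmap

# Batch `predict` over the leading axis of `input_batch`, keeping `params` fixed.
predictions = vmap(partial(predict, params))(input_batch)
# Equivalently, spell out which argument is batched:
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
```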
@@ -273,17 +273,17 @@ Here's an example on an 8-GPU machine:

 ```python
 from jax import random, pmap
-import jax.numpy as np
+import jax.numpy as jnp

 # Create 8 random 5000 x 6000 matrices, one per GPU
 keys = random.split(random.PRNGKey(0), 8)
 mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

 # Run a local matmul on each device in parallel (no data transfer)
-result = pmap(lambda x: np.dot(x, x.T))(mats)  # result.shape is (8, 5000, 5000)
+result = pmap(lambda x: jnp.dot(x, x.T))(mats)  # result.shape is (8, 5000, 5000)

 # Compute the mean on each device in parallel and print the result
-print(pmap(np.mean)(result))
+print(pmap(jnp.mean)(result))
 # prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
 ```

@@ -299,7 +299,7 @@ from jax import lax
 def normalize(x):
   return x / lax.psum(x, 'i')

-print(normalize(np.arange(4.)))
+print(normalize(jnp.arange(4.)))
 # prints [0.         0.16666667 0.33333334 0.5       ]
 ```

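The `psum` collective above only works because `normalize` is meant to run under a `pmap` that binds the axis name `'i'` (the decorator is outside this hunk; the notebook cells later in this diff show the same pattern with `pmap(normalize, axis_name='i')`). These multi-device examples also assume enough devices are attached; a quick, hedged way to check before running them:

```python
import jax

print(jax.device_count())        # devices visible to this process
print(jax.local_device_count())  # devices attached to this host
```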
@@ -313,11 +313,11 @@ from jax import grad

 @pmap
 def f(x):
-  y = np.sin(x)
+  y = jnp.sin(x)
   @pmap
   def g(z):
-    return np.cos(z) * np.tan(y.sum()) * np.tanh(x).sum()
-  return grad(lambda w: np.sum(g(w)))(x)
+    return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
+  return grad(lambda w: jnp.sum(g(w)))(x)

 print(f(x))
 # [[ 0.        , -0.7170853 ],
@@ -325,7 +325,7 @@ print(f(x))
 #  [10.366636  , 13.135289  ],
 #  [ 0.22163185, -0.52112055]]

-print(grad(lambda x: np.sum(f(x)))(x))
+print(grad(lambda x: jnp.sum(f(x)))(x))
 # [[ -3.2369726,  -1.6356447],
 #  [  4.7572474,  11.606951 ],
 #  [-98.524414 ,  42.76499  ],
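The nested `pmap` example closes over an `x` that is defined earlier in the README, outside this hunk. Judging from the 4x2 output above, and from the notebook cells later in this diff, a suitable definition on an 8-device machine would be a sketch like:

```python
import jax.numpy as jnp

# 4 rows for the outer pmap, 2 entries each for the inner pmap: 4 * 2 = 8 devices.
x = jnp.arange(8.).reshape((4, 2))
```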
@ -59,7 +59,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import jax\n",
|
||||
"import jax.numpy as np\n",
|
||||
"import jax.numpy as jnp\n",
|
||||
"from jax import random\n",
|
||||
"\n",
|
||||
"key = random.PRNGKey(0)"
|
||||
@ -92,7 +92,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"y = np.dot(x, x)\n",
|
||||
"y = jnp.dot(x, x)\n",
|
||||
"print(y[0, 0])"
|
||||
]
|
||||
},
|
||||
@ -134,7 +134,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"np.dot(x, x.T)"
|
||||
"jnp.dot(x, x.T)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -147,7 +147,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(np.dot(x, 2 * x)[[0, 2, 1, 0], ..., None, ::-1])"
|
||||
"print(jnp.dot(x, 2 * x)[[0, 2, 1, 0], ..., None, ::-1])"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -160,10 +160,10 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as onp\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"x_cpu = onp.array(x)\n",
|
||||
"%timeit -n 1 -r 1 onp.dot(x_cpu, x_cpu)"
|
||||
"x_cpu = np.array(x)\n",
|
||||
"%timeit -n 1 -r 1 np.dot(x_cpu, x_cpu)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -176,7 +176,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%timeit -n 5 -r 5 np.dot(x, x).block_until_ready()"
|
||||
"%timeit -n 5 -r 5 jnp.dot(x, x).block_until_ready()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -262,14 +262,14 @@
|
||||
"source": [
|
||||
"def predict(params, inputs):\n",
|
||||
" for W, b in params:\n",
|
||||
" outputs = np.dot(inputs, W) + b\n",
|
||||
" inputs = np.tanh(outputs)\n",
|
||||
" outputs = jnp.dot(inputs, W) + b\n",
|
||||
" inputs = jnp.tanh(outputs)\n",
|
||||
" return outputs\n",
|
||||
"\n",
|
||||
"def loss(params, batch):\n",
|
||||
" inputs, targets = batch\n",
|
||||
" predictions = predict(params, inputs)\n",
|
||||
" return np.sum((predictions - targets)**2)\n",
|
||||
" return jnp.sum((predictions - targets)**2)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
@ -459,7 +459,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"grad(jit(grad(jit(grad(np.tanh)))))(1.0)"
|
||||
"grad(jit(grad(jit(grad(jnp.tanh)))))(1.0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -551,7 +551,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"f(np.array([1., 2., 3.]), 5)"
|
||||
"f(jnp.array([1., 2., 3.]), 5)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -565,7 +565,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"try:\n",
|
||||
" g(np.array([1., 2., 3.]), 5)\n",
|
||||
" g(jnp.array([1., 2., 3.]), 5)\n",
|
||||
"except Exception as e:\n",
|
||||
" print(e)\n",
|
||||
" pass"
|
||||
@ -594,7 +594,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"g(np.array([1., 2., 3.]), 5)"
|
||||
"g(jnp.array([1., 2., 3.]), 5)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -630,7 +630,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(vmap(lambda x: x**2)(np.arange(8)))"
|
||||
"print(vmap(lambda x: x**2)(jnp.arange(8)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -645,7 +645,7 @@
|
||||
"source": [
|
||||
"from jax import make_jaxpr\n",
|
||||
"\n",
|
||||
"make_jaxpr(np.dot)(np.ones(8), np.ones(8))"
|
||||
"make_jaxpr(jnp.dot)(jnp.ones(8), jnp.ones(8))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -658,7 +658,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"make_jaxpr(vmap(np.dot))(np.ones((10, 8)), np.ones((10, 8)))"
|
||||
"make_jaxpr(vmap(jnp.dot))(jnp.ones((10, 8)), jnp.ones((10, 8)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -671,7 +671,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"make_jaxpr(vmap(vmap(np.dot)))(np.ones((10, 10, 8)), np.ones((10, 10, 8)))"
|
||||
"make_jaxpr(vmap(vmap(jnp.dot)))(jnp.ones((10, 10, 8)), jnp.ones((10, 10, 8)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -734,7 +734,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"y = pmap(lambda x: x ** 2)(np.arange(8))\n",
|
||||
"y = pmap(lambda x: x ** 2)(jnp.arange(8))\n",
|
||||
"print(y)"
|
||||
]
|
||||
},
|
||||
@ -791,8 +791,8 @@
|
||||
"source": [
|
||||
"keys = random.split(random.PRNGKey(0), 8)\n",
|
||||
"mats = pmap(lambda key: random.normal(key, (5000, 5000)))(keys)\n",
|
||||
"result = pmap(np.dot)(mats, mats)\n",
|
||||
"print(pmap(np.mean)(result))"
|
||||
"result = pmap(jnp.dot)(mats, mats)\n",
|
||||
"print(pmap(jnp.mean)(result))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -805,7 +805,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"timeit -n 5 -r 5 pmap(np.dot)(mats, mats).block_until_ready()"
|
||||
"timeit -n 5 -r 5 pmap(jnp.dot)(mats, mats).block_until_ready()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -835,7 +835,7 @@
|
||||
"def normalize(x):\n",
|
||||
" return x / psum(x, 'i')\n",
|
||||
"\n",
|
||||
"print(normalize(np.arange(8.)))"
|
||||
"print(normalize(jnp.arange(8.)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -856,7 +856,7 @@
|
||||
" total_sum = psum(x, ('rows', 'cols'))\n",
|
||||
" return row_sum, col_sum, total_sum\n",
|
||||
"\n",
|
||||
"x = np.arange(8.).reshape((4, 2))\n",
|
||||
"x = jnp.arange(8.).reshape((4, 2))\n",
|
||||
"a, b, c = f(x)\n",
|
||||
"\n",
|
||||
"print(\"input:\\n\", x)\n",
|
||||
@ -907,11 +907,11 @@
|
||||
"source": [
|
||||
"@pmap\n",
|
||||
"def f(x):\n",
|
||||
" y = np.sin(x)\n",
|
||||
" y = jnp.sin(x)\n",
|
||||
" @pmap\n",
|
||||
" def g(z):\n",
|
||||
" return np.cos(z) * np.tan(y.sum()) * np.tanh(x).sum()\n",
|
||||
" return grad(lambda w: np.sum(g(w)))(x)\n",
|
||||
" return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()\n",
|
||||
" return grad(lambda w: jnp.sum(g(w)))(x)\n",
|
||||
" \n",
|
||||
"f(x)"
|
||||
]
|
||||
@ -926,7 +926,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"grad(lambda x: np.sum(f(x)))(x)"
|
||||
"grad(lambda x: jnp.sum(f(x)))(x)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -956,16 +956,16 @@
|
||||
"@curry\n",
|
||||
"def jacfwd(fun, x):\n",
|
||||
" pushfwd = partial(jvp, fun, (x,)) # jvp!\n",
|
||||
" std_basis = np.eye(onp.size(x)).reshape((-1,) + np.shape(x)),\n",
|
||||
" std_basis = jnp.eye(np.size(x)).reshape((-1,) + jnp.shape(x)),\n",
|
||||
" y, jac_flat = vmap(pushfwd, out_axes=(None, -1))(std_basis) # vmap!\n",
|
||||
" return jac_flat.reshape(np.shape(y) + np.shape(x))\n",
|
||||
" return jac_flat.reshape(jnp.shape(y) + jnp.shape(x))\n",
|
||||
"\n",
|
||||
"@curry\n",
|
||||
"def jacrev(fun, x):\n",
|
||||
" y, pullback = vjp(fun, x) # vjp!\n",
|
||||
" std_basis = np.eye(onp.size(y)).reshape((-1,) + np.shape(y))\n",
|
||||
" std_basis = jnp.eye(np.size(y)).reshape((-1,) + jnp.shape(y))\n",
|
||||
" jac_flat, = vmap(pullback)(std_basis) # vmap!\n",
|
||||
" return jac_flat.reshape(np.shape(y) + np.shape(x))\n",
|
||||
" return jac_flat.reshape(jnp.shape(y) + jnp.shape(x))\n",
|
||||
"\n",
|
||||
"def hessian(fun):\n",
|
||||
" return jit(jacfwd(jacrev(fun))) # jit!"
|
||||
|
@ -80,9 +80,9 @@
|
||||
"source": [
|
||||
"import io\n",
|
||||
"from functools import partial\n",
|
||||
"import numpy as onp\n",
|
||||
"import numpy as np\n",
|
||||
"import jax\n",
|
||||
"import jax.numpy as np\n",
|
||||
"import jax.numpy as jnp\n",
|
||||
"from jax import vmap, jit, grad, ops, lax, config\n",
|
||||
"from jax import random as jr\n",
|
||||
"\n",
|
||||
@ -130,15 +130,15 @@
|
||||
" https://en.wikipedia.org/wiki/Xiaolin_Wu's_line_algorithm\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" ipart = lambda x: np.floor(x).astype('int32')\n",
|
||||
" ipart = lambda x: jnp.floor(x).astype('int32')\n",
|
||||
" round_ = lambda x: ipart(x + 0.5).astype('int32')\n",
|
||||
" fpart = lambda x: x - np.floor(x)\n",
|
||||
" fpart = lambda x: x - jnp.floor(x)\n",
|
||||
" rfpart = lambda x: 1 - fpart(x)\n",
|
||||
"\n",
|
||||
" def plot(im, x, y, c):\n",
|
||||
" return ops.index_add(im, ops.index[x, y], c)\n",
|
||||
"\n",
|
||||
" steep = np.abs(y1 - y0) > np.abs(x1 - x0)\n",
|
||||
" steep = jnp.abs(y1 - y0) > jnp.abs(x1 - x0)\n",
|
||||
" cond_swap = lambda cond, x: lax.cond(cond, x, lambda x: (x[1], x[0]), x, lambda x: x)\n",
|
||||
" \n",
|
||||
" (x0, y0) = cond_swap(steep, (x0, y0))\n",
|
||||
@ -149,7 +149,7 @@
|
||||
"\n",
|
||||
" dx = x1 - x0\n",
|
||||
" dy = y1 - y0\n",
|
||||
" gradient = np.where(dx == 0.0, 1.0, dy/dx)\n",
|
||||
" gradient = jnp.where(dx == 0.0, 1.0, dy/dx)\n",
|
||||
"\n",
|
||||
" # handle first endpoint\n",
|
||||
" xend = round_(x0)\n",
|
||||
@ -211,12 +211,12 @@
|
||||
" return im\n",
|
||||
"\n",
|
||||
"def img_adjust(data):\n",
|
||||
" oim = onp.array(data)\n",
|
||||
" hist, bin_edges = onp.histogram(oim.flat, bins=256*256)\n",
|
||||
" oim = np.array(data)\n",
|
||||
" hist, bin_edges = np.histogram(oim.flat, bins=256*256)\n",
|
||||
" bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2\n",
|
||||
" cdf = hist.cumsum()\n",
|
||||
" cdf = cdf / float(cdf[-1])\n",
|
||||
" return onp.interp(oim.flat, bin_centers, cdf).reshape(oim.shape)\n",
|
||||
" return np.interp(oim.flat, bin_centers, cdf).reshape(oim.shape)\n",
|
||||
"\n",
|
||||
"def imify(arr, vmin=None, vmax=None, cmap=None, origin=None):\n",
|
||||
" arr = img_adjust(arr)\n",
|
||||
@ -239,16 +239,16 @@
|
||||
" display_png(dat, raw=True)\n",
|
||||
"\n",
|
||||
"def pack_images(images, rows, cols):\n",
|
||||
" shape = onp.shape(images)\n",
|
||||
" shape = np.shape(images)\n",
|
||||
" width, height, depth = shape[-3:]\n",
|
||||
" images = onp.reshape(images, (-1, width, height, depth))\n",
|
||||
" batch = onp.shape(images)[0]\n",
|
||||
" rows = onp.minimum(rows, batch)\n",
|
||||
" cols = onp.minimum(batch // rows, cols)\n",
|
||||
" images = np.reshape(images, (-1, width, height, depth))\n",
|
||||
" batch = np.shape(images)[0]\n",
|
||||
" rows = np.minimum(rows, batch)\n",
|
||||
" cols = np.minimum(batch // rows, cols)\n",
|
||||
" images = images[:rows * cols]\n",
|
||||
" images = onp.reshape(images, (rows, cols, width, height, depth))\n",
|
||||
" images = onp.transpose(images, [0, 2, 1, 3, 4])\n",
|
||||
" images = onp.reshape(images, [rows * width, cols * height, depth])\n",
|
||||
" images = np.reshape(images, (rows, cols, width, height, depth))\n",
|
||||
" images = np.transpose(images, [0, 2, 1, 3, 4])\n",
|
||||
" images = np.reshape(images, [rows * width, cols * height, depth])\n",
|
||||
" return images"
|
||||
],
|
||||
"execution_count": 0,
|
||||
@ -281,7 +281,7 @@
|
||||
"@jit\n",
|
||||
"def f(state, t):\n",
|
||||
" x, y, z = state\n",
|
||||
" return np.array([sigma * (y - x), x * (rho - z) - y, x * y - beta * z])"
|
||||
" return jnp.array([sigma * (y - x), x * (rho - z) - y, x * y - beta * z])"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
@ -343,8 +343,8 @@
|
||||
"N = 40000\n",
|
||||
"\n",
|
||||
"# set initial condition\n",
|
||||
"state0 = np.array([1., 1., 1.])\n",
|
||||
"ys = np.zeros((N,) + state0.shape)\n",
|
||||
"state0 = jnp.array([1., 1., 1.])\n",
|
||||
"ys = jnp.zeros((N,) + state0.shape)\n",
|
||||
"ys = ops.index_update(ys, ops.index[0], state0)\n",
|
||||
"\n",
|
||||
"# solve for N steps\n",
|
||||
@ -368,7 +368,7 @@
|
||||
"# fast, jitted plotting function\n",
|
||||
"@partial(jax.jit, static_argnums=(2,3,4,5))\n",
|
||||
"def jplotter(xs, zs, xlim, zlim, xN, zN):\n",
|
||||
" im = np.zeros((xN, zN))\n",
|
||||
" im = jnp.zeros((xN, zN))\n",
|
||||
" xpixels = (xs - xlim[0])/(1.0 * (xlim[1] - xlim[0])) * xN\n",
|
||||
" zpixels = (zs - zlim[0])/(1.0 * (zlim[1] - zlim[0])) * zN\n",
|
||||
" def body_fun(i, im):\n",
|
||||
@ -403,17 +403,17 @@
|
||||
"N = 4000\n",
|
||||
"\n",
|
||||
"# set some initial conditions for each replicate\n",
|
||||
"ys = np.zeros((N_dev, N, 3))\n",
|
||||
"ys = jnp.zeros((N_dev, N, 3))\n",
|
||||
"state0 = jr.uniform(jr.PRNGKey(1), \n",
|
||||
" minval=-1., maxval=1.,\n",
|
||||
" shape=(N_dev, 3))\n",
|
||||
"state0 = state0 * np.array([18,18,1]) + np.array((0.,0.,10.))\n",
|
||||
"state0 = state0 * jnp.array([18,18,1]) + jnp.array((0.,0.,10.))\n",
|
||||
"ys = ops.index_update(ys, ops.index[:, 0], state0)\n",
|
||||
"\n",
|
||||
"# solve each replicate in parallel using `pmap` of rk4 solver:\n",
|
||||
"ys = jax.pmap(rk4)(ys, \n",
|
||||
" 0.004 * np.ones(N_dev), \n",
|
||||
" N * np.ones(N_dev, dtype=onp.int32)\n",
|
||||
" 0.004 * jnp.ones(N_dev), \n",
|
||||
" N * jnp.ones(N_dev, dtype=np.int32)\n",
|
||||
" ).block_until_ready()"
|
||||
],
|
||||
"execution_count": 0,
|
||||
@ -430,7 +430,7 @@
|
||||
"# parallel plotter using lexical closure and pmap'd core plotting function\n",
|
||||
"def pplotter(_xs, _zs, xlim, zlim, xN, zN):\n",
|
||||
" N_dev = _xs.shape[0]\n",
|
||||
" im = np.zeros((N_dev, xN, zN))\n",
|
||||
" im = jnp.zeros((N_dev, xN, zN))\n",
|
||||
" @jax.pmap\n",
|
||||
" def plotfn(im, xs, zs):\n",
|
||||
" xpixels = (xs - xlim[0])/(1.0 * (xlim[1] - xlim[0])) * xN\n",
|
||||
@ -459,7 +459,7 @@
|
||||
"plot_image(im[:,::-1].T, cmap='magma')\n",
|
||||
"# below, plot combined ODE traces\n",
|
||||
"ims = pplotter(ys[...,0], ys[...,2], xlim, zlim, xN*4, zN*4)\n",
|
||||
"plot_image(np.sum(ims, axis=0)[:,::-1].T, cmap='magma')"
|
||||
"plot_image(jnp.sum(ims, axis=0)[:,::-1].T, cmap='magma')"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
|
@ -81,7 +81,7 @@
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"import jax.numpy as np"
|
||||
"import jax.numpy as jnp"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
@ -137,7 +137,7 @@
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"result = pmap(lambda x: x ** 2)(np.arange(7))\n",
|
||||
"result = pmap(lambda x: x ** 2)(jnp.arange(7))\n",
|
||||
"print(result)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
@ -163,11 +163,11 @@
|
||||
"source": [
|
||||
"from jax import vmap\n",
|
||||
"\n",
|
||||
"x = np.array([1., 2., 3.])\n",
|
||||
"y = np.array([2., 4., 6.])\n",
|
||||
"x = jnp.array([1., 2., 3.])\n",
|
||||
"y = jnp.array([2., 4., 6.])\n",
|
||||
"\n",
|
||||
"print(vmap(np.add)(x, y))\n",
|
||||
"print(pmap(np.add)(x, y))"
|
||||
"print(vmap(jnp.add)(x, y))\n",
|
||||
"print(pmap(jnp.add)(x, y))"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
@ -193,12 +193,12 @@
|
||||
"from jax import make_jaxpr\n",
|
||||
"\n",
|
||||
"def f(x, y):\n",
|
||||
" a = np.dot(x, y)\n",
|
||||
" b = np.tanh(a)\n",
|
||||
" a = jnp.dot(x, y)\n",
|
||||
" b = jnp.tanh(a)\n",
|
||||
" return b\n",
|
||||
"\n",
|
||||
"xs = np.ones((8, 2, 3))\n",
|
||||
"ys = np.ones((8, 3, 4))\n",
|
||||
"xs = jnp.ones((8, 2, 3))\n",
|
||||
"ys = jnp.ones((8, 3, 4))\n",
|
||||
"\n",
|
||||
"print(\"f jaxpr\")\n",
|
||||
"print(make_jaxpr(f)(xs[0], ys[0]))\n",
|
||||
@ -236,7 +236,7 @@
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"y = pmap(lambda x: x ** 2)(np.arange(8))\n",
|
||||
"y = pmap(lambda x: x ** 2)(jnp.arange(8))\n",
|
||||
"z = y / 2\n",
|
||||
"print(z)"
|
||||
],
|
||||
@ -313,8 +313,8 @@
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"import numpy as onp\n",
|
||||
"onp.sin(y)"
|
||||
"import numpy as np\n",
|
||||
"np.sin(y)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
@ -360,7 +360,7 @@
|
||||
},
|
||||
"source": [
|
||||
"# run a local matmul on each device in parallel (no data transfer)\n",
|
||||
"result = pmap(lambda x: np.dot(x, x.T))(mats)\n",
|
||||
"result = pmap(lambda x: jnp.dot(x, x.T))(mats)\n",
|
||||
"result.shape"
|
||||
],
|
||||
"execution_count": 0,
|
||||
@ -375,7 +375,7 @@
|
||||
},
|
||||
"source": [
|
||||
"# compute the mean on each device in parallel and print the results\n",
|
||||
"print(pmap(np.mean)(result))"
|
||||
"print(pmap(jnp.mean)(result))"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
@ -423,7 +423,7 @@
|
||||
"from jax import lax\n",
|
||||
"\n",
|
||||
"normalize = lambda x: x / lax.psum(x, axis_name='i')\n",
|
||||
"result = pmap(normalize, axis_name='i')(np.arange(4.))\n",
|
||||
"result = pmap(normalize, axis_name='i')(jnp.arange(4.))\n",
|
||||
"print(result)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
@ -455,7 +455,7 @@
|
||||
"def normalize(x):\n",
|
||||
" return x / lax.psum(x, 'i')\n",
|
||||
"\n",
|
||||
"print(normalize(np.arange(4.)))"
|
||||
"print(normalize(jnp.arange(4.)))"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
@ -486,7 +486,7 @@
|
||||
" doubly_normed = x / lax.psum(x, ('rows', 'cols'))\n",
|
||||
" return row_normed, col_normed, doubly_normed\n",
|
||||
"\n",
|
||||
"x = np.arange(8.).reshape((4, 2))\n",
|
||||
"x = jnp.arange(8.).reshape((4, 2))\n",
|
||||
"a, b, c = f(x)\n",
|
||||
"\n",
|
||||
"print(a)\n",
|
||||
@ -538,14 +538,14 @@
|
||||
"def step(board_slice):\n",
|
||||
" left, right = board_slice[:1], board_slice[-1:]\n",
|
||||
" right, left = send_left(left, 'i'), send_right(right, 'i')\n",
|
||||
" enlarged_board_slice = np.concatenate([left, board_slice, right])\n",
|
||||
" enlarged_board_slice = jnp.concatenate([left, board_slice, right])\n",
|
||||
" return update_board(enlarged_board_slice)\n",
|
||||
"\n",
|
||||
"def print_board(board):\n",
|
||||
" print(''.join('*' if x else ' ' for x in board.ravel()))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"board = onp.zeros(40, dtype=bool)\n",
|
||||
"board = np.zeros(40, dtype=bool)\n",
|
||||
"board[board.shape[0] // 2] = True\n",
|
||||
"reshaped_board = board.reshape((device_count, -1))\n",
|
||||
"\n",
|
||||
@ -589,11 +589,11 @@
|
||||
"\n",
|
||||
"@pmap\n",
|
||||
"def f(x):\n",
|
||||
" y = np.sin(x)\n",
|
||||
" y = jnp.sin(x)\n",
|
||||
" @pmap\n",
|
||||
" def g(z):\n",
|
||||
" return np.cos(z) * np.tan(y.sum()) * np.tanh(x).sum()\n",
|
||||
" return grad(lambda w: np.sum(g(w)))(x)\n",
|
||||
" return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()\n",
|
||||
" return grad(lambda w: jnp.sum(g(w)))(x)\n",
|
||||
" \n",
|
||||
"f(x)"
|
||||
],
|
||||
@ -608,7 +608,7 @@
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"grad(lambda x: np.sum(f(x)))(x)"
|
||||
"grad(lambda x: jnp.sum(f(x)))(x)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
|
@ -84,8 +84,8 @@
|
||||
"from jax import jit, pmap\n",
|
||||
"from jax import lax\n",
|
||||
"from jax import tree_util\n",
|
||||
"import jax.numpy as np\n",
|
||||
"import numpy as onp\n",
|
||||
"import jax.numpy as jnp\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import skimage.filters\n",
|
||||
"import proglog\n",
|
||||
@ -123,13 +123,13 @@
|
||||
"def halo_exchange_padding(array, padding=1, axis=0, axis_name='x'):\n",
|
||||
" if not padding > 0:\n",
|
||||
" raise ValueError(f'invalid padding: {padding}')\n",
|
||||
" array = np.array(array)\n",
|
||||
" array = jnp.array(array)\n",
|
||||
" if array.ndim == 0:\n",
|
||||
" return array\n",
|
||||
" left = slice_along_axis(array, slice(None, padding), axis)\n",
|
||||
" right = slice_along_axis(array, slice(-padding, None), axis)\n",
|
||||
" right, left = send_left(left, axis_name), send_right(right, axis_name)\n",
|
||||
" return np.concatenate([left, array, right], axis)\n",
|
||||
" return jnp.concatenate([left, array, right], axis)\n",
|
||||
"\n",
|
||||
"@tree_vectorize\n",
|
||||
"def halo_exchange_inplace(array, padding=1, axis=0, axis_name='x'):\n",
|
||||
@ -153,14 +153,14 @@
|
||||
" new_shape = list(array.shape)\n",
|
||||
" new_shape[split_axis] = tile_size\n",
|
||||
" new_shape.insert(split_axis, num_splits)\n",
|
||||
" return np.moveaxis(np.reshape(array, new_shape), split_axis, tile_id_axis)\n",
|
||||
" return jnp.moveaxis(jnp.reshape(array, new_shape), split_axis, tile_id_axis)\n",
|
||||
"\n",
|
||||
"def stack_with_reshape(array, *, split_axis=0, tile_id_axis=None):\n",
|
||||
" if tile_id_axis is None:\n",
|
||||
" tile_id_axis = split_axis\n",
|
||||
" array = np.moveaxis(array, tile_id_axis, split_axis)\n",
|
||||
" array = jnp.moveaxis(array, tile_id_axis, split_axis)\n",
|
||||
" new_shape = array.shape[:split_axis] + (-1,) + array.shape[split_axis+2:]\n",
|
||||
" return np.reshape(array, new_shape)\n",
|
||||
" return jnp.reshape(array, new_shape)\n",
|
||||
"\n",
|
||||
"def shard(func):\n",
|
||||
" def wrapper(state):\n",
|
||||
@ -178,7 +178,7 @@
|
||||
" sliced = slice_along_axis(array, index, axis)\n",
|
||||
" padding = [(0, 0)] * array.ndim\n",
|
||||
" padding[axis] = (-min(offset, 0), max(offset, 0))\n",
|
||||
" return np.pad(sliced, padding, mode='constant', constant_values=0)\n",
|
||||
" return jnp.pad(sliced, padding, mode='constant', constant_values=0)\n",
|
||||
"\n",
|
||||
"def laplacian(array, step=1):\n",
|
||||
" left = shift(array, +1, axis=0)\n",
|
||||
@ -204,10 +204,10 @@
|
||||
"\n",
|
||||
"# Time stepping\n",
|
||||
"\n",
|
||||
"def multi_step(state, count, dt=1/np.sqrt(2), c=1):\n",
|
||||
"def multi_step(state, count, dt=1/jnp.sqrt(2), c=1):\n",
|
||||
" return lax.fori_loop(0, count, lambda i, s: leapfrog_step(s, dt, c), state)\n",
|
||||
"\n",
|
||||
"def multi_step_pmap(state, count, dt=1/np.sqrt(2), c=1, exchange_interval=1,\n",
|
||||
"def multi_step_pmap(state, count, dt=1/jnp.sqrt(2), c=1, exchange_interval=1,\n",
|
||||
" save_interval=1):\n",
|
||||
"\n",
|
||||
" def exchange_and_multi_step(state_padded):\n",
|
||||
@ -231,7 +231,7 @@
|
||||
" tree_util.tree_map(lambda x: x.copy_to_host_async(), state)\n",
|
||||
" results.append(state)\n",
|
||||
" results = jax.device_get(results)\n",
|
||||
" return tree_util.tree_multimap(lambda *xs: onp.stack([onp.array(x) for x in xs]), *results)\n",
|
||||
" return tree_util.tree_multimap(lambda *xs: np.stack([np.array(x) for x in xs]), *results)\n",
|
||||
"\n",
|
||||
"multi_step_jit = jax.jit(multi_step)"
|
||||
]
|
||||
@ -256,9 +256,9 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x = np.linspace(0, 8, num=8*1024, endpoint=False)\n",
|
||||
"y = np.linspace(0, 1, num=1*1024, endpoint=False)\n",
|
||||
"x_mesh, y_mesh = np.meshgrid(x, y, indexing='ij')\n",
|
||||
"x = jnp.linspace(0, 8, num=8*1024, endpoint=False)\n",
|
||||
"y = jnp.linspace(0, 1, num=1*1024, endpoint=False)\n",
|
||||
"x_mesh, y_mesh = jnp.meshgrid(x, y, indexing='ij')\n",
|
||||
"\n",
|
||||
"# NOTE: smooth initial conditions are important, so we aren't exciting\n",
|
||||
"# arbitrarily high frequencies (that cannot be resolved)\n",
|
||||
@ -266,13 +266,13 @@
|
||||
" ((x_mesh - 1/3) ** 2 + (y_mesh - 1/4) ** 2) < 0.1 ** 2,\n",
|
||||
" sigma=1)\n",
|
||||
"\n",
|
||||
"# u = np.exp(-((x_mesh - 1/3) ** 2 + (y_mesh - 1/4) ** 2) / 0.1 ** 2)\n",
|
||||
"# u = jnp.exp(-((x_mesh - 1/3) ** 2 + (y_mesh - 1/4) ** 2) / 0.1 ** 2)\n",
|
||||
"\n",
|
||||
"# u = skimage.filters.gaussian(\n",
|
||||
"# (x_mesh > 1/3) & (x_mesh < 1/2) & (y_mesh > 1/3) & (y_mesh < 1/2),\n",
|
||||
"# sigma=5)\n",
|
||||
"\n",
|
||||
"v = np.zeros_like(u)\n",
|
||||
"v = jnp.zeros_like(u)\n",
|
||||
"c = 1 # could also use a 2D array matching the mesh shape"
|
||||
]
|
||||
},
|
||||
@ -445,7 +445,7 @@
|
||||
" images = []\n",
|
||||
" for frame in data:\n",
|
||||
" if vmax is None:\n",
|
||||
" this_vmax = onp.max(abs(frame))\n",
|
||||
" this_vmax = np.max(abs(frame))\n",
|
||||
" else:\n",
|
||||
" this_vmax = vmax\n",
|
||||
" norm = matplotlib.colors.Normalize(vmin=-this_vmax, vmax=this_vmax)\n",
|
||||
@ -474,7 +474,7 @@
|
||||
"source": [
|
||||
"# Show Movie\n",
|
||||
"proglog.default_bar_logger = partial(proglog.default_bar_logger, None)\n",
|
||||
"ImageSequenceClip([onp.array(im) for im in images], fps=25).ipython_display()"
|
||||
"ImageSequenceClip([np.array(im) for im in images], fps=25).ipython_display()"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@@ -4,11 +4,11 @@ Asynchronous dispatch
 JAX uses asynchronous dispatch to hide Python overheads. Consider the following
 program:

->>> import numpy as onp
->>> from jax import numpy as np
+>>> import numpy as np
+>>> import jax.numpy as jnp
 >>> from jax import random
 >>> x = random.uniform(random.PRNGKey(0), (1000, 1000))
->>> np.dot(x, x) + 3.  # doctest: +SKIP
+>>> jnp.dot(x, x) + 3.  # doctest: +SKIP
 DeviceArray([[258.01971436, 249.64862061, 257.13372803, ...,
               236.67948914, 250.68939209, 241.36853027],
              [265.65979004, 256.28912354, 262.18252563, ...,
@@ -23,7 +23,7 @@ DeviceArray([[258.01971436, 249.64862061, 257.13372803, ...,
              [257.16134644, 254.7543335, 259.08300781, ..., 241.59848022,
               248.62597656, 243.22348022]], dtype=float32)

-When an operation such as :code:`np.dot(x, x)` is executed, JAX does not wait
+When an operation such as :code:`jnp.dot(x, x)` is executed, JAX does not wait
 for the operation to complete before returning control to the Python program.
 Instead, JAX returns a :class:`~jax.DeviceArray` value, which is a future,
 i.e., a value that will be produced in the future on an accelerator device but
@@ -44,7 +44,7 @@ arbitrary amounts of work and avoid having the accelerator wait.

 Asynchronous dispatch has a slightly surprising consequence for microbenchmarks.

->>> %time np.dot(x, x)  # doctest: +SKIP
+>>> %time jnp.dot(x, x)  # doctest: +SKIP
 CPU times: user 267 µs, sys: 93 µs, total: 360 µs
 Wall time: 269 µs
 DeviceArray([[255.01972961, 246.64862061, 254.13371277, ...,
@@ -70,7 +70,7 @@ use the :meth:`~jax.DeviceArray.block_until_ready` method on a
 :class:`DeviceArray` value to wait for the computation that produced it to
 complete.

->>> %time onp.asarray(np.dot(x, x))  # doctest: +SKIP
+>>> %time np.asarray(jnp.dot(x, x))  # doctest: +SKIP
 CPU times: user 61.1 ms, sys: 0 ns, total: 61.1 ms
 Wall time: 8.09 ms
 Out[16]:
@@ -87,7 +87,7 @@ array([[255.01973, 246.64862, 254.13371, ..., 233.67949, 247.68939,
         258.337  ],
        [254.16135, 251.75433, 256.083  , ..., 238.59848, 245.62598,
         240.22348]], dtype=float32)
->>> %time np.dot(x, x).block_until_ready()  # doctest: +SKIP
+>>> %time jnp.dot(x, x).block_until_ready()  # doctest: +SKIP
 CPU times: user 50.3 ms, sys: 928 µs, total: 51.2 ms
 Wall time: 4.92 ms
 DeviceArray([[255.01972961, 246.64862061, 254.13371277, ...,
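The `%time` lines above are IPython magics. A plain-Python sketch of the same measurement (the array size is illustrative) separates the cost of dispatching the operation from the cost of waiting for its result:

```python
import time
import jax.numpy as jnp
from jax import random

x = random.uniform(random.PRNGKey(0), (1000, 1000))

t0 = time.perf_counter()
y = jnp.dot(x, x)            # returns almost immediately: only the dispatch happened
t1 = time.perf_counter()
y.block_until_ready()        # blocks until the device has actually finished
t2 = time.perf_counter()

print(f"dispatch took {t1 - t0:.6f} s, compute took {t2 - t1:.6f} s")
```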
@@ -86,7 +86,7 @@ The jaxpr primitives are documented in the :py:mod:`jax.lax` module.
 For example, here is the jaxpr produced for the function ``func1`` below

 >>> from jax import make_jaxpr
->>> from jax import numpy as jnp
+>>> import jax.numpy as jnp
 >>> def func1(first, second):
 ...   temp = first + jnp.sin(second) * 3.
 ...   return jnp.sum(temp)
@@ -187,7 +187,7 @@ JAX produces the following jaxpr
   in (e,) }

 When tracing ``func6``, the function ``func5`` is invoked with a constant value
-(``onp.ones(8)``) for the second argument. As a result, the sub-expression
+(``np.ones(8)``) for the second argument. As a result, the sub-expression
 ``jnp.sin(second) * 3.`` is constant-folded.
 There are two ConstVars, ``b`` (standing for ``jnp.sin(second) * 3.``) and ``d``
 (standing for ``jnp.ones(8)``). Unfortunately, it is not easy to tell from the
@@ -348,7 +348,7 @@ and :py:func:`jax.lax.fori_loop`
 In the above signature, “C” stands for the type of a the loop “carry” value.
 For example, here is an example fori loop

->>> import numpy as onp
+>>> import numpy as np
 >>>
 >>> def func10(arg, n):
 ...   ones = jnp.ones(arg.shape)  # A constant
@@ -356,7 +356,7 @@ For example, here is an example fori loop
 ...            lambda i, carry: carry + ones * 3. + arg,
 ...            arg + ones)
 ...
->>> print(make_jaxpr(func10)(onp.ones(16), 5))
+>>> print(make_jaxpr(func10)(np.ones(16), 5))
 { lambda c d ; a b.
   let e = add a d
       _ _ f = while[ body_jaxpr={ lambda ; e g a b c.
@@ -414,7 +414,7 @@ For the example consider the function ``func11`` below
 ...     return (carry + ae1 * ae2 + extra, carry)
 ...   return lax.scan(body, 0., (arr, ones))
 ...
->>> print(make_jaxpr(func11)(onp.ones(16), 5.))
+>>> print(make_jaxpr(func11)(np.ones(16), 5.))
 { lambda c ; a b.
   let d e = scan[ jaxpr={ lambda ; f a b c.
                           let d = mul b c
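The scan hunk above is concerned with the jaxpr that `lax.scan` produces. As a complementary, hedged sketch (not part of the documented text), here is a minimal scan you can run directly, a cumulative sum, together with its jaxpr:

```python
import jax.numpy as jnp
from jax import lax, make_jaxpr

def cumsum(xs):
  def body(carry, x):
    carry = carry + x
    return carry, carry          # (new carry, per-step output)
  _, ys = lax.scan(body, 0.0, xs)
  return ys

print(cumsum(jnp.arange(5.)))          # [ 0.  1.  3.  6. 10.]
print(make_jaxpr(cumsum)(jnp.arange(5.)))
```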
@ -35,12 +35,12 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as onp\n",
|
||||
"import numpy as np\n",
|
||||
"from jax import grad, jit\n",
|
||||
"from jax import lax\n",
|
||||
"from jax import random\n",
|
||||
"import jax\n",
|
||||
"import jax.numpy as np\n",
|
||||
"import jax.numpy as jnp\n",
|
||||
"import matplotlib as mpl\n",
|
||||
"from matplotlib import pyplot as plt\n",
|
||||
"from matplotlib import rcParams\n",
|
||||
@ -117,7 +117,7 @@
|
||||
"print (\"Second call: \", jit(impure_print_side_effect)(5.))\n",
|
||||
"\n",
|
||||
"# JAX re-runs the Python function when the type or shape of the argument changes\n",
|
||||
"print (\"Third call, different type: \", jit(impure_print_side_effect)(np.array([5.])))"
|
||||
"print (\"Third call, different type: \", jit(impure_print_side_effect)(jnp.array([5.])))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -157,7 +157,7 @@
|
||||
"\n",
|
||||
"# JAX re-runs the Python function when the type or shape of the argument changes\n",
|
||||
"# This will end up reading the latest value of the global\n",
|
||||
"print (\"Third call, different type: \", jit(impure_uses_globals)(np.array([4.])))"
|
||||
"print (\"Third call, different type: \", jit(impure_uses_globals)(jnp.array([4.])))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -334,7 +334,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"numpy_array = onp.zeros((3,3), dtype=np.float32)\n",
|
||||
"numpy_array = np.zeros((3,3), dtype=np.float32)\n",
|
||||
"print(\"original array:\")\n",
|
||||
"print(numpy_array)\n",
|
||||
"\n",
|
||||
@ -379,7 +379,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"jax_array = np.zeros((3,3), dtype=np.float32)\n",
|
||||
"jax_array = jnp.zeros((3,3), dtype=jnp.float32)\n",
|
||||
"\n",
|
||||
"# In place update of JAX's array will yield an error!\n",
|
||||
"try:\n",
|
||||
@ -470,7 +470,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"jax_array = np.zeros((3, 3))\n",
|
||||
"jax_array = jnp.zeros((3, 3))\n",
|
||||
"print(\"original array:\")\n",
|
||||
"print(jax_array)\n",
|
||||
"\n",
|
||||
@ -537,7 +537,7 @@
|
||||
],
|
||||
"source": [
|
||||
"print(\"original array:\")\n",
|
||||
"jax_array = np.ones((5, 6))\n",
|
||||
"jax_array = jnp.ones((5, 6))\n",
|
||||
"print(jax_array)\n",
|
||||
"\n",
|
||||
"new_jax_array = index_add(jax_array, index[::2, 3:], 7.)\n",
|
||||
@ -591,7 +591,7 @@
|
||||
],
|
||||
"source": [
|
||||
"try:\n",
|
||||
" onp.arange(10)[11]\n",
|
||||
" np.arange(10)[11]\n",
|
||||
"except Exception as e:\n",
|
||||
" print(\"Exception {}\".format(e))"
|
||||
]
|
||||
@ -633,14 +633,14 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"np.arange(10)[11]"
|
||||
"jnp.arange(10)[11]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Note that due to this behavior np.nanargmin and np.nanargmax return -1 for slices consisting of NaNs whereas Numpy would throw an error."
|
||||
"Note that due to this behavior jnp.nanargmin and jnp.nanargmax return -1 for slices consisting of NaNs whereas Numpy would throw an error."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -700,9 +700,9 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(onp.random.random())\n",
|
||||
"print(onp.random.random())\n",
|
||||
"print(onp.random.random())"
|
||||
"print(np.random.random())\n",
|
||||
"print(np.random.random())\n",
|
||||
"print(np.random.random())"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -725,8 +725,8 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"onp.random.seed(0)\n",
|
||||
"rng_state = onp.random.get_state()\n",
|
||||
"np.random.seed(0)\n",
|
||||
"rng_state = np.random.get_state()\n",
|
||||
"#print(rng_state)\n",
|
||||
"# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,\n",
|
||||
"# 2481403966, 4042607538, 337614300, ... 614 more numbers..., \n",
|
||||
@ -753,23 +753,23 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"_ = onp.random.uniform()\n",
|
||||
"rng_state = onp.random.get_state()\n",
|
||||
"_ = np.random.uniform()\n",
|
||||
"rng_state = np.random.get_state()\n",
|
||||
"#print(rng_state) \n",
|
||||
"# --> ('MT19937', array([2443250962, 1093594115, 1878467924,\n",
|
||||
"# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)\n",
|
||||
"\n",
|
||||
"# Let's exhaust the entropy in this PRNG statevector\n",
|
||||
"for i in range(311):\n",
|
||||
" _ = onp.random.uniform()\n",
|
||||
"rng_state = onp.random.get_state()\n",
|
||||
" _ = np.random.uniform()\n",
|
||||
"rng_state = np.random.get_state()\n",
|
||||
"#print(rng_state) \n",
|
||||
"# --> ('MT19937', array([2443250962, 1093594115, 1878467924,\n",
|
||||
"# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)\n",
|
||||
"\n",
|
||||
"# Next call iterates the RNG state for a new batch of fake \"entropy\".\n",
|
||||
"_ = onp.random.uniform()\n",
|
||||
"rng_state = onp.random.get_state()\n",
|
||||
"_ = np.random.uniform()\n",
|
||||
"rng_state = np.random.get_state()\n",
|
||||
"# print(rng_state) \n",
|
||||
"# --> ('MT19937', array([1499117434, 2949980591, 2242547484, \n",
|
||||
"# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)"
|
||||
@ -1146,7 +1146,7 @@
|
||||
" y = y + x[i]\n",
|
||||
" return y\n",
|
||||
"\n",
|
||||
"print(g(np.array([1., 2., 3.])))"
|
||||
"print(g(jnp.array([1., 2., 3.])))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -1206,13 +1206,13 @@
|
||||
"\n",
|
||||
"When we `jit`-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don't have to re-compile on each function evaluation.\n",
|
||||
"\n",
|
||||
"For example, if we evaluate an `@jit` function on the array `np.array([1., 2., 3.], np.float32)`, we might want to compile code that we can reuse to evaluate the function on `np.array([4., 5., 6.], np.float32)` to save on compile time.\n",
|
||||
"For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time.\n",
|
||||
"\n",
|
||||
"To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/google/jax/blob/master/jax/abstract_arrays.py), and different transformations use different abstraction levels.\n",
|
||||
"\n",
|
||||
"By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), np.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.\n",
|
||||
"By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.\n",
|
||||
"\n",
|
||||
"But there's a tradeoff here: if we trace a Python function on a `ShapedArray((), np.float32)` that isn't committed to a specific concrete value, when we hit a line like `if x < 3`, the expression `x < 3` evaluates to an abstract `ShapedArray((), np.bool_)` that represents the set `{True, False}`. When Python attempts to coerce that to a concrete `True` or `False`, we get an error: we don't know which branch to take, and can't continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace.\n",
|
||||
"But there's a tradeoff here: if we trace a Python function on a `ShapedArray((), jnp.float32)` that isn't committed to a specific concrete value, when we hit a line like `if x < 3`, the expression `x < 3` evaluates to an abstract `ShapedArray((), jnp.bool_)` that represents the set `{True, False}`. When Python attempts to coerce that to a concrete `True` or `False`, we get an error: we don't know which branch to take, and can't continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace.\n",
|
||||
"\n",
|
||||
"The good news is that you can control this tradeoff yourself. By having `jit` trace on more refined abstract values, you can relax the traceability constraints. For example, using the `static_argnums` argument to `jit`, we can specify to trace on concrete values of some arguments. Here's that example function again:"
|
||||
]
|
||||
@ -1295,7 +1295,7 @@
|
||||
"\n",
|
||||
"f = jit(f, static_argnums=(1,))\n",
|
||||
"\n",
|
||||
"f(np.array([2., 3., 4.]), 2)"
|
||||
"f(jnp.array([2., 3., 4.]), 2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -1347,7 +1347,7 @@
|
||||
],
|
||||
"source": [
|
||||
"def example_fun(length, val):\n",
|
||||
" return np.ones((length,)) * val\n",
|
||||
" return jnp.ones((length,)) * val\n",
|
||||
"# un-jit'd works fine\n",
|
||||
"print(example_fun(5, 4))\n",
|
||||
"\n",
|
||||
@ -1487,7 +1487,7 @@
|
||||
"source": [
|
||||
"from jax import lax\n",
|
||||
"\n",
|
||||
"operand = np.array([0.])\n",
|
||||
"operand = jnp.array([0.])\n",
|
||||
"lax.cond(True, operand, lambda x: x+1, operand, lambda x: x-1)\n",
|
||||
"# --> array([1.], dtype=float32)\n",
|
||||
"lax.cond(False, operand, lambda x: x+1, operand, lambda x: x-1)\n",
|
||||
@ -1690,10 +1690,10 @@
|
||||
],
|
||||
"source": [
|
||||
"# 2D kernel - HWIO layout\n",
|
||||
"kernel = onp.zeros((3, 3, 3, 3), dtype=np.float32)\n",
|
||||
"kernel += onp.array([[1, 1, 0],\n",
|
||||
"kernel = np.zeros((3, 3, 3, 3), dtype=jnp.float32)\n",
|
||||
"kernel += np.array([[1, 1, 0],\n",
|
||||
" [1, 0,-1],\n",
|
||||
" [0,-1,-1]])[:, :, onp.newaxis, onp.newaxis]\n",
|
||||
" [0,-1,-1]])[:, :, np.newaxis, np.newaxis]\n",
|
||||
"\n",
|
||||
"print(\"Edge Conv kernel:\")\n",
|
||||
"plt.imshow(kernel[:, :, 0, 0]);"
|
||||
@ -1744,7 +1744,7 @@
|
||||
],
|
||||
"source": [
|
||||
"# NHWC layout\n",
|
||||
"img = onp.zeros((1, 200, 198, 3), dtype=np.float32)\n",
|
||||
"img = np.zeros((1, 200, 198, 3), dtype=jnp.float32)\n",
|
||||
"for k in range(3):\n",
|
||||
" x = 30 + 60*k\n",
|
||||
" y = 20 + 60*k\n",
|
||||
@ -1811,14 +1811,14 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"out = lax.conv(np.transpose(img,[0,3,1,2]), # lhs = NCHW image tensor\n",
|
||||
" np.transpose(kernel,[3,2,0,1]), # rhs = OIHW conv kernel tensor\n",
|
||||
"out = lax.conv(jnp.transpose(img,[0,3,1,2]), # lhs = NCHW image tensor\n",
|
||||
" jnp.transpose(kernel,[3,2,0,1]), # rhs = OIHW conv kernel tensor\n",
|
||||
" (1, 1), # window strides\n",
|
||||
" 'SAME') # padding mode\n",
|
||||
"print(\"out shape: \", out.shape)\n",
|
||||
"print(\"First output channel:\")\n",
|
||||
"plt.figure(figsize=(10,10))\n",
|
||||
"plt.imshow(onp.array(out)[0,0,:,:]);"
|
||||
"plt.imshow(np.array(out)[0,0,:,:]);"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -1857,8 +1857,8 @@
|
||||
],
|
||||
"source": [
|
||||
"out = lax.conv_with_general_padding(\n",
|
||||
" np.transpose(img,[0,3,1,2]), # lhs = NCHW image tensor\n",
|
||||
" np.transpose(kernel,[2,3,0,1]), # rhs = IOHW conv kernel tensor\n",
|
||||
" jnp.transpose(img,[0,3,1,2]), # lhs = NCHW image tensor\n",
|
||||
" jnp.transpose(kernel,[2,3,0,1]), # rhs = IOHW conv kernel tensor\n",
|
||||
" (1, 1), # window strides\n",
|
||||
" ((2,2),(2,2)), # general padding 2x2\n",
|
||||
" (1,1), # lhs/image dilation\n",
|
||||
@ -1866,7 +1866,7 @@
|
||||
"print(\"out shape: \", out.shape)\n",
|
||||
"print(\"First output channel:\")\n",
|
||||
"plt.figure(figsize=(10,10))\n",
|
||||
"plt.imshow(onp.array(out)[0,0,:,:]);"
|
||||
"plt.imshow(np.array(out)[0,0,:,:]);"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -1973,7 +1973,7 @@
|
||||
"print(\"out shape: \", out.shape)\n",
|
||||
"print(\"First output channel:\")\n",
|
||||
"plt.figure(figsize=(10,10))\n",
|
||||
"plt.imshow(onp.array(out)[0,:,:,0]);"
|
||||
"plt.imshow(np.array(out)[0,:,:,0]);"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -2031,7 +2031,7 @@
|
||||
"print(\"out shape: \", out.shape, \"DIFFERENT from above!\")\n",
|
||||
"print(\"First output channel:\")\n",
|
||||
"plt.figure(figsize=(10,10))\n",
|
||||
"plt.imshow(onp.array(out)[0,:,:,0]);"
|
||||
"plt.imshow(np.array(out)[0,:,:,0]);"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -2089,7 +2089,7 @@
|
||||
"print(\"out shape: \", out.shape, \" <-- half the size of above\")\n",
|
||||
"plt.figure(figsize=(10,10))\n",
|
||||
"print(\"First output channel:\")\n",
|
||||
"plt.imshow(onp.array(out)[0,:,:,0]);"
|
||||
"plt.imshow(np.array(out)[0,:,:,0]);"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -2147,7 +2147,7 @@
|
||||
"print(\"out shape: \", out.shape)\n",
|
||||
"plt.figure(figsize=(10,10))\n",
|
||||
"print(\"First output channel:\")\n",
|
||||
"plt.imshow(onp.array(out)[0,:,:,0]);"
|
||||
"plt.imshow(np.array(out)[0,:,:,0]);"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -2205,7 +2205,7 @@
|
||||
"print(\"out shape: \", out.shape, \"<-- larger than original!\")\n",
|
||||
"plt.figure(figsize=(10,10))\n",
|
||||
"print(\"First output channel:\")\n",
|
||||
"plt.imshow(onp.array(out)[0,:,:,0]);"
|
||||
"plt.imshow(np.array(out)[0,:,:,0]);"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -2259,7 +2259,7 @@
|
||||
"\n",
|
||||
"# transposed conv = 180deg kernel roation plus LHS dilation\n",
|
||||
"# rotate kernel 180deg:\n",
|
||||
"kernel_rot = np.rot90(np.rot90(kernel, axes=(0,1)), axes=(0,1))\n",
|
||||
"kernel_rot = jnp.rot90(jnp.rot90(kernel, axes=(0,1)), axes=(0,1))\n",
|
||||
"# need a custom output padding:\n",
|
||||
"padding = ((2, 1), (2, 1))\n",
|
||||
"out = lax.conv_general_dilated(img, # lhs = image tensor\n",
|
||||
@ -2272,7 +2272,7 @@
|
||||
"print(\"out shape: \", out.shape, \"<-- transposed_conv\")\n",
|
||||
"plt.figure(figsize=(10,10))\n",
|
||||
"print(\"First output channel:\")\n",
|
||||
"plt.imshow(onp.array(out)[0,:,:,0]);"
|
||||
"plt.imshow(np.array(out)[0,:,:,0]);"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -2344,11 +2344,11 @@
|
||||
],
|
||||
"source": [
|
||||
"# 1D kernel - WIO layout\n",
|
||||
"kernel = onp.array([[[1, 0, -1], [-1, 0, 1]], \n",
|
||||
"kernel = np.array([[[1, 0, -1], [-1, 0, 1]], \n",
|
||||
" [[1, 1, 1], [-1, -1, -1]]], \n",
|
||||
" dtype=np.float32).transpose([2,1,0])\n",
|
||||
" dtype=jnp.float32).transpose([2,1,0])\n",
|
||||
"# 1D data - NWC layout\n",
|
||||
"data = onp.zeros((1, 200, 2), dtype=np.float32)\n",
|
||||
"data = np.zeros((1, 200, 2), dtype=jnp.float32)\n",
|
||||
"for i in range(2):\n",
|
||||
" for k in range(2):\n",
|
||||
" x = 35*i + 30 + 60*k\n",
|
||||
@ -2433,16 +2433,16 @@
|
||||
],
|
||||
"source": [
|
||||
"# Random 3D kernel - HWDIO layout\n",
|
||||
"kernel = onp.array([\n",
|
||||
"kernel = np.array([\n",
|
||||
" [[0, 0, 0], [0, 1, 0], [0, 0, 0]],\n",
|
||||
" [[0, -1, 0], [-1, 0, -1], [0, -1, 0]], \n",
|
||||
" [[0, 0, 0], [0, 1, 0], [0, 0, 0]]], \n",
|
||||
" dtype=np.float32)[:, :, :, onp.newaxis, onp.newaxis]\n",
|
||||
" dtype=jnp.float32)[:, :, :, np.newaxis, np.newaxis]\n",
|
||||
"\n",
|
||||
"# 3D data - NHWDC layout\n",
|
||||
"data = onp.zeros((1, 30, 30, 30, 1), dtype=np.float32)\n",
|
||||
"x, y, z = onp.mgrid[0:1:30j, 0:1:30j, 0:1:30j]\n",
|
||||
"data += (onp.sin(2*x*np.pi)*onp.cos(2*y*np.pi)*onp.cos(2*z*np.pi))[None,:,:,:,None]\n",
|
||||
"data = np.zeros((1, 30, 30, 30, 1), dtype=jnp.float32)\n",
|
||||
"x, y, z = np.mgrid[0:1:30j, 0:1:30j, 0:1:30j]\n",
|
||||
"data += (np.sin(2*x*jnp.pi)*np.cos(2*y*jnp.pi)*np.cos(2*z*jnp.pi))[None,:,:,:,None]\n",
|
||||
"\n",
|
||||
"print(\"in shapes:\", data.shape, kernel.shape)\n",
|
||||
"dn = lax.conv_dimension_numbers(data.shape, kernel.shape,\n",
|
||||
@ -2461,8 +2461,8 @@
|
||||
"# Make some simple 3d density plots:\n",
|
||||
"from mpl_toolkits.mplot3d import Axes3D\n",
|
||||
"def make_alpha(cmap):\n",
|
||||
" my_cmap = cmap(np.arange(cmap.N))\n",
|
||||
" my_cmap[:,-1] = np.linspace(0, 1, cmap.N)**3\n",
|
||||
" my_cmap = cmap(jnp.arange(cmap.N))\n",
|
||||
" my_cmap[:,-1] = jnp.linspace(0, 1, cmap.N)**3\n",
|
||||
" return mpl.colors.ListedColormap(my_cmap)\n",
|
||||
"my_cmap = make_alpha(plt.cm.viridis)\n",
|
||||
"fig = plt.figure()\n",
|
||||
@ -2516,13 +2516,13 @@
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"```\n",
|
||||
"In [1]: import jax.numpy as np\n",
|
||||
"In [1]: import jax.numpy as jnp\n",
|
||||
"\n",
|
||||
"In [2]: np.divide(0., 0.)\n",
|
||||
"In [2]: jnp.divide(0., 0.)\n",
|
||||
"---------------------------------------------------------------------------\n",
|
||||
"FloatingPointError Traceback (most recent call last)\n",
|
||||
"<ipython-input-2-f2e2c413b437> in <module>()\n",
|
||||
"----> 1 np.divide(0., 0.)\n",
|
||||
"----> 1 jnp.divide(0., 0.)\n",
|
||||
"\n",
|
||||
".../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)\n",
|
||||
" 343 return floor_divide(x1, x2)\n",
|
||||
@ -2549,7 +2549,7 @@
|
||||
"\n",
|
||||
".../jax/jax/interpreters/xla.pyc in handle_result(device_buffer)\n",
|
||||
" 103 py_val = device_buffer.to_py()\n",
|
||||
" 104 if onp.any(onp.isnan(py_val)):\n",
|
||||
" 104 if np.any(np.isnan(py_val)):\n",
|
||||
"--> 105 raise FloatingPointError(\"invalid value\")\n",
|
||||
" 106 else:\n",
|
||||
" 107 return DeviceArray(device_buffer, *result_shape)\n",
|
||||
@ -2580,9 +2580,9 @@
|
||||
" ...: return a + b * c\n",
|
||||
" ...:\n",
|
||||
"\n",
|
||||
"In [6]: x = np.array([2., 0.])\n",
|
||||
"In [6]: x = jnp.array([2., 0.])\n",
|
||||
"\n",
|
||||
"In [7]: y = np.array([3., 0.])\n",
|
||||
"In [7]: y = jnp.array([3., 0.])\n",
|
||||
"\n",
|
||||
"In [8]: f(x, y)\n",
|
||||
"Invalid value encountered in the output of a jit function. Calling the de-optimized version.\n",
|
||||
@ -2673,7 +2673,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"x = random.uniform(random.PRNGKey(0), (1000,), dtype=np.float64)\n",
|
||||
"x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)\n",
|
||||
"x.dtype"
|
||||
]
|
||||
},
|
||||
@ -2746,8 +2746,9 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from jax import numpy as np, random\n",
|
||||
"x = random.uniform(random.PRNGKey(0), (1000,), dtype=np.float64)\n",
|
||||
"import jax.numpy as jnp\n",
|
||||
"from jax import random\n",
|
||||
"x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)\n",
|
||||
"x.dtype # --> dtype('float64')"
|
||||
]
|
||||
},
|
||||
@ -2803,4 +2804,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 1
|
||||
}
|
||||
}
|
@ -64,19 +64,19 @@
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"import jax.numpy as np\n",
|
||||
"import jax.numpy as jnp\n",
|
||||
"from jax import custom_jvp\n",
|
||||
"\n",
|
||||
"@custom_jvp\n",
|
||||
"def f(x, y):\n",
|
||||
" return np.sin(x) * y\n",
|
||||
" return jnp.sin(x) * y\n",
|
||||
"\n",
|
||||
"@f.defjvp\n",
|
||||
"def f_jvp(primals, tangents):\n",
|
||||
" x, y = primals\n",
|
||||
" x_dot, y_dot = tangents\n",
|
||||
" primal_out = f(x, y)\n",
|
||||
" tangent_out = np.cos(x) * x_dot * y + np.sin(x) * y_dot\n",
|
||||
" tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot\n",
|
||||
" return primal_out, tangent_out"
|
||||
],
|
||||
"execution_count": 0,
|
||||
@ -128,10 +128,10 @@
|
||||
"\n",
|
||||
"@custom_jvp\n",
|
||||
"def f(x, y):\n",
|
||||
" return np.sin(x) * y\n",
|
||||
" return jnp.sin(x) * y\n",
|
||||
"\n",
|
||||
"f.defjvps(lambda x_dot, primal_out, x, y: np.cos(x) * x_dot * y,\n",
|
||||
" lambda y_dot, primal_out, x, y: np.sin(x) * y_dot)"
|
||||
"f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,\n",
|
||||
" lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
@ -190,10 +190,10 @@
|
||||
"\n",
|
||||
"@custom_vjp\n",
|
||||
"def f(x, y):\n",
|
||||
" return np.sin(x) * y\n",
|
||||
" return jnp.sin(x) * y\n",
|
||||
"\n",
|
||||
"def f_fwd(x, y):\n",
|
||||
" return f(x, y), (np.cos(x), np.sin(x), y)\n",
|
||||
" return f(x, y), (jnp.cos(x), jnp.sin(x), y)\n",
|
||||
"\n",
|
||||
"def f_bwd(res, g):\n",
|
||||
" cos_x, sin_x, y = res\n",
|
||||
@ -279,10 +279,10 @@
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"import jax.numpy as np\n",
|
||||
"import jax.numpy as jnp\n",
|
||||
"\n",
|
||||
"def log1pexp(x):\n",
|
||||
" return np.log(1. + np.exp(x))\n",
|
||||
" return jnp.log(1. + jnp.exp(x))\n",
|
||||
"\n",
|
||||
"log1pexp(3.)"
|
||||
],
|
||||
@ -328,7 +328,7 @@
|
||||
"\n",
|
||||
"print(jit(log1pexp)(3.))\n",
|
||||
"print(jit(grad(log1pexp))(3.))\n",
|
||||
"print(vmap(jit(grad(log1pexp)))(np.arange(3.)))"
|
||||
"print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))"
|
||||
],
|
||||
"execution_count": 8,
|
||||
"outputs": [
|
||||
@ -434,7 +434,7 @@
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"Stepping through how the jaxpr would be evaluated, we can see that the last line would involve multiplying values that floating point math will round to 0 and $\\infty$, respectively, which is never a good idea. That is, we're effectively evaluating `lambda x: (1 / (1 + np.exp(x))) * np.exp(x)` for large `x`, which effectively turns into `0. * np.inf`.\n",
|
||||
"Stepping through how the jaxpr would be evaluated, we can see that the last line would involve multiplying values that floating point math will round to 0 and $\\infty$, respectively, which is never a good idea. That is, we're effectively evaluating `lambda x: (1 / (1 + jnp.exp(x))) * jnp.exp(x)` for large `x`, which effectively turns into `0. * jnp.inf`.\n",
|
||||
"\n",
|
||||
"Instead of generating such large and small values, hoping for a cancellation that floats can't always provide, we'd rather just express the derivative function as a more numerically stable program. In particular, we can write a program that more closely evaluates the equal mathematical expression $1 - \\frac{1}{1 + e^x}$, with no cancellation in sight.\n",
|
||||
"\n",
|
||||
@ -457,14 +457,14 @@
|
||||
"\n",
|
||||
"@custom_jvp\n",
|
||||
"def log1pexp(x):\n",
|
||||
" return np.log(1. + np.exp(x))\n",
|
||||
" return jnp.log(1. + jnp.exp(x))\n",
|
||||
"\n",
|
||||
"@log1pexp.defjvp\n",
|
||||
"def log1pexp_jvp(primals, tangents):\n",
|
||||
" x, = primals\n",
|
||||
" x_dot, = tangents\n",
|
||||
" ans = log1pexp(x)\n",
|
||||
" ans_dot = (1 - 1/(1 + np.exp(x))) * x_dot\n",
|
||||
" ans_dot = (1 - 1/(1 + jnp.exp(x))) * x_dot\n",
|
||||
" return ans, ans_dot"
|
||||
],
|
||||
"execution_count": 0,
|
||||
@ -509,7 +509,7 @@
|
||||
"source": [
|
||||
"print(jit(log1pexp)(3.))\n",
|
||||
"print(jit(grad(log1pexp))(3.))\n",
|
||||
"print(vmap(jit(grad(log1pexp)))(np.arange(3.)))"
|
||||
"print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))"
|
||||
],
|
||||
"execution_count": 13,
|
||||
"outputs": [
|
||||
@ -544,9 +544,9 @@
|
||||
"source": [
|
||||
"@custom_jvp\n",
|
||||
"def log1pexp(x):\n",
|
||||
" return np.log(1. + np.exp(x))\n",
|
||||
" return jnp.log(1. + jnp.exp(x))\n",
|
||||
"\n",
|
||||
"log1pexp.defjvps(lambda t, ans, x: (1 - 1/(1 + np.exp(x))) * t)"
|
||||
"log1pexp.defjvps(lambda t, ans, x: (1 - 1/(1 + jnp.exp(x))) * t)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
@ -566,7 +566,7 @@
|
||||
"print(grad(log1pexp)(100.))\n",
|
||||
"print(jit(log1pexp)(3.))\n",
|
||||
"print(jit(grad(log1pexp))(3.))\n",
|
||||
"print(vmap(jit(grad(log1pexp)))(np.arange(3.)))"
|
||||
"print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))"
|
||||
],
|
||||
"execution_count": 15,
|
||||
"outputs": [
|
||||
@ -614,7 +614,7 @@
|
||||
},
|
||||
"source": [
|
||||
"def f(x):\n",
|
||||
" return x / (1 + np.sqrt(x))"
|
||||
" return x / (1 + jnp.sqrt(x))"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
@ -676,14 +676,14 @@
|
||||
"source": [
|
||||
"@custom_jvp\n",
|
||||
"def f(x):\n",
|
||||
" return x / (1 + np.sqrt(x))\n",
|
||||
" return x / (1 + jnp.sqrt(x))\n",
|
||||
"\n",
|
||||
"@f.defjvp\n",
|
||||
"def f_jvp(primals, tangents):\n",
|
||||
" x, = primals\n",
|
||||
" x_dot, = tangents\n",
|
||||
" ans = f(x)\n",
|
||||
" ans_dot = ((np.sqrt(x) + 2) / (2 * (np.sqrt(x) + 1)**2)) * x_dot\n",
|
||||
" ans_dot = ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * x_dot\n",
|
||||
" return ans, ans_dot"
|
||||
],
|
||||
"execution_count": 0,
|
||||
@ -734,9 +734,9 @@
|
||||
"source": [
|
||||
"@custom_jvp\n",
|
||||
"def f(x):\n",
|
||||
" return x / (1 + np.sqrt(x))\n",
|
||||
" return x / (1 + jnp.sqrt(x))\n",
|
||||
"\n",
|
||||
"f.defjvps(lambda t, ans, x: ((np.sqrt(x) + 2) / (2 * (np.sqrt(x) + 1)**2)) * t)"
|
||||
"f.defjvps(lambda t, ans, x: ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * t)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
@ -777,7 +777,7 @@
|
||||
"\n",
|
||||
"While in some cases we want to express a mathematical differentiation computation, in other cases we may even want to take a step away from mathematics to adjust the computation autodiff performs. One canonical example is reverse-mode gradient clipping.\n",
|
||||
"\n",
|
||||
"For gradient clipping, we can use `np.clip` together with a `jax.custom_vjp` reverse-mode-only rule:"
|
||||
"For gradient clipping, we can use `jnp.clip` together with a `jax.custom_vjp` reverse-mode-only rule:"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -799,7 +799,7 @@
|
||||
" return x, None # no residual values to save\n",
|
||||
"\n",
|
||||
"def clip_gradient_bwd(lo, hi, _, g):\n",
|
||||
" return (np.clip(g, lo, hi),)\n",
|
||||
" return (jnp.clip(g, lo, hi),)\n",
|
||||
"\n",
|
||||
"clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)"
|
||||
],
|
||||
@ -821,10 +821,10 @@
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from jax import vmap\n",
|
||||
"\n",
|
||||
"t = np.linspace(0, 10, 1000)\n",
|
||||
"t = jnp.linspace(0, 10, 1000)\n",
|
||||
"\n",
|
||||
"plt.plot(np.sin(t))\n",
|
||||
"plt.plot(vmap(grad(np.sin))(t))"
|
||||
"plt.plot(jnp.sin(t))\n",
|
||||
"plt.plot(vmap(grad(jnp.sin))(t))"
|
||||
],
|
||||
"execution_count": 23,
|
||||
"outputs": [
|
||||
@ -868,7 +868,7 @@
|
||||
"source": [
|
||||
"def clip_sin(x):\n",
|
||||
" x = clip_gradient(-0.75, 0.75, x)\n",
|
||||
" return np.sin(x)\n",
|
||||
" return jnp.sin(x)\n",
|
||||
"\n",
|
||||
"plt.plot(clip_sin(t))\n",
|
||||
"plt.plot(vmap(grad(clip_sin))(t))"
|
||||
@ -963,7 +963,7 @@
|
||||
"def fixed_point(f, a, x_guess):\n",
|
||||
" def cond_fun(carry):\n",
|
||||
" x_prev, x = carry\n",
|
||||
" return np.abs(x_prev - x) > 1e-6\n",
|
||||
" return jnp.abs(x_prev - x) > 1e-6\n",
|
||||
"\n",
|
||||
" def body_fun(carry):\n",
|
||||
" _, x = carry\n",
|
||||
@ -1049,7 +1049,7 @@
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"print(jit(vmap(newton_sqrt))(np.array([1., 2., 3., 4.])))"
|
||||
"print(jit(vmap(newton_sqrt))(jnp.array([1., 2., 3., 4.])))"
|
||||
],
|
||||
"execution_count": 28,
|
||||
"outputs": [
|
||||
@ -1108,7 +1108,7 @@
|
||||
"def fixed_point(f, a, x_guess):\n",
|
||||
" def cond_fun(carry):\n",
|
||||
" x_prev, x = carry\n",
|
||||
" return np.abs(x_prev - x) > 1e-6\n",
|
||||
" return jnp.abs(x_prev - x) > 1e-6\n",
|
||||
"\n",
|
||||
" def body_fun(carry):\n",
|
||||
" _, x = carry\n",
|
||||
@ -1127,7 +1127,7 @@
|
||||
" a_bar, = vjp_a(fixed_point(partial(rev_iter, f),\n",
|
||||
" (a, x_star, x_star_bar),\n",
|
||||
" x_star_bar))\n",
|
||||
" return a_bar, np.zeros_like(x_star)\n",
|
||||
" return a_bar, jnp.zeros_like(x_star)\n",
|
||||
" \n",
|
||||
"def rev_iter(f, packed, u):\n",
|
||||
" a, x_star, x_star_bar = packed\n",
|
||||
@ -1198,7 +1198,7 @@
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"We can check our answers by differentiating `np.sqrt`, which uses a totally different implementation:"
|
||||
"We can check our answers by differentiating `jnp.sqrt`, which uses a totally different implementation:"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -1213,8 +1213,8 @@
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"print(grad(np.sqrt)(2.))\n",
|
||||
"print(grad(grad(np.sqrt))(2.))"
|
||||
"print(grad(jnp.sqrt)(2.))\n",
|
||||
"print(grad(grad(jnp.sqrt))(2.))"
|
||||
],
|
||||
"execution_count": 32,
|
||||
"outputs": [
|
||||
@ -1270,18 +1270,18 @@
|
||||
},
|
||||
"source": [
|
||||
"from jax import custom_jvp\n",
|
||||
"import jax.numpy as np\n",
|
||||
"import jax.numpy as jnp\n",
|
||||
"\n",
|
||||
"# f :: a -> b\n",
|
||||
"@custom_jvp\n",
|
||||
"def f(x):\n",
|
||||
" return np.sin(x)\n",
|
||||
" return jnp.sin(x)\n",
|
||||
"\n",
|
||||
"# f_jvp :: (a, T a) -> (b, T b)\n",
|
||||
"def f_jvp(primals, tangents):\n",
|
||||
" x, = primals\n",
|
||||
" t, = tangents\n",
|
||||
" return f(x), np.cos(x) * t\n",
|
||||
" return f(x), jnp.cos(x) * t\n",
|
||||
"\n",
|
||||
"f.defjvp(f_jvp)"
|
||||
],
|
||||
@ -1479,9 +1479,9 @@
|
||||
"source": [
|
||||
"@custom_jvp\n",
|
||||
"def f(x):\n",
|
||||
" return np.sin(x)\n",
|
||||
" return jnp.sin(x)\n",
|
||||
"\n",
|
||||
"f.defjvps(lambda t, ans, x: np.cos(x) * t)"
|
||||
"f.defjvps(lambda t, ans, x: jnp.cos(x) * t)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
@ -1656,14 +1656,14 @@
|
||||
"@custom_jvp\n",
|
||||
"def f(x):\n",
|
||||
" print('called f!') # a harmless side-effect\n",
|
||||
" return np.sin(x)\n",
|
||||
" return jnp.sin(x)\n",
|
||||
"\n",
|
||||
"@f.defjvp\n",
|
||||
"def f_jvp(primals, tangents):\n",
|
||||
" print('called f_jvp!') # a harmless side-effect\n",
|
||||
" x, = primals\n",
|
||||
" t, = tangents\n",
|
||||
" return f(x), np.cos(x) * t"
|
||||
" return f(x), jnp.cos(x) * t"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
@ -1708,7 +1708,7 @@
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"print(vmap(f)(np.arange(3.)))\n",
|
||||
"print(vmap(f)(jnp.arange(3.)))\n",
|
||||
"print(jit(f)(3.))"
|
||||
],
|
||||
"execution_count": 46,
|
||||
@ -1860,9 +1860,9 @@
|
||||
"@custom_jvp\n",
|
||||
"def f(x):\n",
|
||||
" if x > 0:\n",
|
||||
" return np.sin(x)\n",
|
||||
" return jnp.sin(x)\n",
|
||||
" else:\n",
|
||||
" return np.cos(x)\n",
|
||||
" return jnp.cos(x)\n",
|
||||
"\n",
|
||||
"@f.defjvp\n",
|
||||
"def f_jvp(primals, tangents):\n",
|
||||
@ -1925,16 +1925,16 @@
|
||||
},
|
||||
"source": [
|
||||
"from jax import custom_vjp\n",
|
||||
"import jax.numpy as np\n",
|
||||
"import jax.numpy as jnp\n",
|
||||
"\n",
|
||||
"# f :: a -> b\n",
|
||||
"@custom_vjp\n",
|
||||
"def f(x):\n",
|
||||
" return np.sin(x)\n",
|
||||
" return jnp.sin(x)\n",
|
||||
"\n",
|
||||
"# f_fwd :: a -> (b, c)\n",
|
||||
"def f_fwd(x):\n",
|
||||
" return f(x), np.cos(x)\n",
|
||||
" return f(x), jnp.cos(x)\n",
|
||||
"\n",
|
||||
"# f_bwd :: (c, CT b) -> CT a\n",
|
||||
"def f_bwd(cos_x, y_bar):\n",
|
||||
@ -2010,10 +2010,10 @@
|
||||
"\n",
|
||||
"@custom_vjp\n",
|
||||
"def f(x, y):\n",
|
||||
" return np.sin(x) * y\n",
|
||||
" return jnp.sin(x) * y\n",
|
||||
"\n",
|
||||
"def f_fwd(x, y):\n",
|
||||
" return f(x, y), (np.cos(x), np.sin(x), y)\n",
|
||||
" return f(x, y), (jnp.cos(x), jnp.sin(x), y)\n",
|
||||
"\n",
|
||||
"def f_bwd(res, g):\n",
|
||||
" cos_x, sin_x, y = res\n",
|
||||
@ -2080,11 +2080,11 @@
|
||||
"@custom_vjp\n",
|
||||
"def f(x):\n",
|
||||
" print(\"called f!\")\n",
|
||||
" return np.sin(x)\n",
|
||||
" return jnp.sin(x)\n",
|
||||
"\n",
|
||||
"def f_fwd(x):\n",
|
||||
" print(\"called f_fwd!\")\n",
|
||||
" return f(x), np.cos(x)\n",
|
||||
" return f(x), jnp.cos(x)\n",
|
||||
"\n",
|
||||
"def f_bwd(cos_x, y_bar):\n",
|
||||
" print(\"called f_bwd!\")\n",
|
||||
@ -2304,7 +2304,7 @@
|
||||
"def foo(x):\n",
|
||||
" y = x ** 2\n",
|
||||
" y = debug(y) # insert pdb in corresponding backward pass step\n",
|
||||
" return np.sin(y)"
|
||||
" return jnp.sin(y)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
@ -2368,7 +2368,7 @@
|
||||
"def f(pt):\n",
|
||||
" x, y = pt.x, pt.y\n",
|
||||
" return {'a': x ** 2,\n",
|
||||
" 'b': (np.sin(x), np.cos(y))}\n",
|
||||
" 'b': (jnp.sin(x), jnp.cos(y))}\n",
|
||||
"\n",
|
||||
"@f.defjvp\n",
|
||||
"def f_jvp(primals, tangents):\n",
|
||||
@ -2376,7 +2376,7 @@
|
||||
" pt_dot, = tangents\n",
|
||||
" ans = f(pt)\n",
|
||||
" ans_dot = {'a': 2 * pt.x * pt_dot.x,\n",
|
||||
" 'b': (np.cos(pt.x) * pt_dot.x, -np.sin(pt.y) * pt_dot.y)}\n",
|
||||
" 'b': (jnp.cos(pt.x) * pt_dot.x, -jnp.sin(pt.y) * pt_dot.y)}\n",
|
||||
" return ans, ans_dot\n",
|
||||
"\n",
|
||||
"def fun(pt):\n",
|
||||
@ -2460,15 +2460,15 @@
|
||||
"def f(pt):\n",
|
||||
" x, y = pt.x, pt.y\n",
|
||||
" return {'a': x ** 2,\n",
|
||||
" 'b': (np.sin(x), np.cos(y))}\n",
|
||||
" 'b': (jnp.sin(x), jnp.cos(y))}\n",
|
||||
"\n",
|
||||
"def f_fwd(pt):\n",
|
||||
" return f(pt), pt\n",
|
||||
"\n",
|
||||
"def f_bwd(pt, g):\n",
|
||||
" a_bar, (b0_bar, b1_bar) = g['a'], g['b']\n",
|
||||
" x_bar = 2 * pt.x * a_bar + np.cos(pt.x) * b0_bar\n",
|
||||
" y_bar = -np.sin(pt.y) * b1_bar\n",
|
||||
" x_bar = 2 * pt.x * a_bar + jnp.cos(pt.x) * b0_bar\n",
|
||||
" y_bar = -jnp.sin(pt.y) * b1_bar\n",
|
||||
" return (Point(x_bar, y_bar),)\n",
|
||||
"\n",
|
||||
"f.defvjp(f_fwd, f_bwd)\n",
|
||||
|
@ -239,7 +239,7 @@
|
||||
},
|
||||
"source": [
|
||||
"import jax.numpy as jnp\n",
|
||||
"import numpy as onp\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"@trace(\"multiply_add_numpy\")\n",
|
||||
"def multiply_add_numpy(x, y, z):\n",
|
||||
@ -429,7 +429,7 @@
|
||||
" the concrete result of the primitive.\n",
|
||||
" \"\"\"\n",
|
||||
" # Note that we can use the original numpy, which is not JAX traceable\n",
|
||||
" return onp.add(onp.multiply(x, y), z)\n",
|
||||
" return np.add(np.multiply(x, y), z)\n",
|
||||
"\n",
|
||||
"# Now we register the primal implementation with JAX\n",
|
||||
"multiply_add_p.def_impl(multiply_add_impl)"
|
||||
@ -1154,7 +1154,7 @@
|
||||
" File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 340, in grad_f\n",
|
||||
" _, g = value_and_grad_f(*args, **kwargs)\n",
|
||||
" File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 398, in value_and_grad_f\n",
|
||||
" g = vjp_py(onp.ones((), dtype=dtype))\n",
|
||||
" g = vjp_py(np.ones((), dtype=dtype))\n",
|
||||
"NotImplementedError: Reverse-mode differentiation rule for 'multiply_add' not implemented\n"
|
||||
],
|
||||
"name": "stderr"
|
||||
@ -1473,8 +1473,8 @@
|
||||
"source": [
|
||||
"# The arguments are two vectors instead of two scalars\n",
|
||||
"with expectNotImplementedError():\n",
|
||||
" api.vmap(square_add_prim, in_axes=0, out_axes=0)(onp.array([2., 3.]),\n",
|
||||
" onp.array([10., 20.]))"
|
||||
" api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),\n",
|
||||
" np.array([10., 20.]))"
|
||||
],
|
||||
"execution_count": 22,
|
||||
"outputs": [
|
||||
@ -1500,7 +1500,7 @@
|
||||
"\n",
|
||||
"Traceback (most recent call last):\n",
|
||||
" File \"<ipython-input-22-70154d0e2ab6>\", line 3, in <module>\n",
|
||||
" onp.array([10., 20.]))\n",
|
||||
" np.array([10., 20.]))\n",
|
||||
" File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 611, in batched_fun\n",
|
||||
" lambda: _flatten_axes(out_tree(), out_axes))\n",
|
||||
" File \"/usr/local/lib/python3.6/dist-packages/jax/interpreters/batching.py\", line 41, in batch\n",
|
||||
@ -1574,9 +1574,9 @@
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"assert onp.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)(\n",
|
||||
" onp.array([2., 3.]),\n",
|
||||
" onp.array([10., 20.])),\n",
|
||||
"assert np.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)(\n",
|
||||
" np.array([2., 3.]),\n",
|
||||
" np.array([10., 20.])),\n",
|
||||
" [14., 29.])"
|
||||
],
|
||||
"execution_count": 24,
|
||||
@ -1622,9 +1622,9 @@
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"assert onp.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0))\n",
|
||||
" (onp.array([2., 3.]),\n",
|
||||
" onp.array([10., 20.])),\n",
|
||||
"assert np.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0))\n",
|
||||
" (np.array([2., 3.]),\n",
|
||||
" np.array([10., 20.])),\n",
|
||||
" [14., 29.])"
|
||||
],
|
||||
"execution_count": 25,
|
||||
|
@ -47,7 +47,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import jax.numpy as np\n",
|
||||
"import jax.numpy as jnp\n",
|
||||
"from jax import grad, jit, vmap\n",
|
||||
"from jax import random"
|
||||
]
|
||||
@ -118,17 +118,17 @@
|
||||
"from jax.scipy.special import logsumexp\n",
|
||||
"\n",
|
||||
"def relu(x):\n",
|
||||
" return np.maximum(0, x)\n",
|
||||
" return jnp.maximum(0, x)\n",
|
||||
"\n",
|
||||
"def predict(params, image):\n",
|
||||
" # per-example predictions\n",
|
||||
" activations = image\n",
|
||||
" for w, b in params[:-1]:\n",
|
||||
" outputs = np.dot(w, activations) + b\n",
|
||||
" outputs = jnp.dot(w, activations) + b\n",
|
||||
" activations = relu(outputs)\n",
|
||||
" \n",
|
||||
" final_w, final_b = params[-1]\n",
|
||||
" logits = np.dot(final_w, activations) + final_b\n",
|
||||
" logits = jnp.dot(final_w, activations) + final_b\n",
|
||||
" return logits - logsumexp(logits)"
|
||||
]
|
||||
},
|
||||
@ -262,18 +262,18 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def one_hot(x, k, dtype=np.float32):\n",
|
||||
"def one_hot(x, k, dtype=jnp.float32):\n",
|
||||
" \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n",
|
||||
" return np.array(x[:, None] == np.arange(k), dtype)\n",
|
||||
" return jnp.array(x[:, None] == jnp.arange(k), dtype)\n",
|
||||
" \n",
|
||||
"def accuracy(params, images, targets):\n",
|
||||
" target_class = np.argmax(targets, axis=1)\n",
|
||||
" predicted_class = np.argmax(batched_predict(params, images), axis=1)\n",
|
||||
" return np.mean(predicted_class == target_class)\n",
|
||||
" target_class = jnp.argmax(targets, axis=1)\n",
|
||||
" predicted_class = jnp.argmax(batched_predict(params, images), axis=1)\n",
|
||||
" return jnp.mean(predicted_class == target_class)\n",
|
||||
"\n",
|
||||
"def loss(params, images, targets):\n",
|
||||
" preds = batched_predict(params, images)\n",
|
||||
" return -np.mean(preds * targets)\n",
|
||||
" return -jnp.mean(preds * targets)\n",
|
||||
"\n",
|
||||
"@jit\n",
|
||||
"def update(params, x, y):\n",
|
||||
@ -334,18 +334,18 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as onp\n",
|
||||
"import numpy as np\n",
|
||||
"from torch.utils import data\n",
|
||||
"from torchvision.datasets import MNIST\n",
|
||||
"\n",
|
||||
"def numpy_collate(batch):\n",
|
||||
" if isinstance(batch[0], onp.ndarray):\n",
|
||||
" return onp.stack(batch)\n",
|
||||
" if isinstance(batch[0], np.ndarray):\n",
|
||||
" return np.stack(batch)\n",
|
||||
" elif isinstance(batch[0], (tuple,list)):\n",
|
||||
" transposed = zip(*batch)\n",
|
||||
" return [numpy_collate(samples) for samples in transposed]\n",
|
||||
" else:\n",
|
||||
" return onp.array(batch)\n",
|
||||
" return np.array(batch)\n",
|
||||
"\n",
|
||||
"class NumpyLoader(data.DataLoader):\n",
|
||||
" def __init__(self, dataset, batch_size=1,\n",
|
||||
@ -367,7 +367,7 @@
|
||||
"\n",
|
||||
"class FlattenAndCast(object):\n",
|
||||
" def __call__(self, pic):\n",
|
||||
" return onp.ravel(onp.array(pic, dtype=np.float32))"
|
||||
" return np.ravel(np.array(pic, dtype=jnp.float32))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -526,13 +526,13 @@
|
||||
],
|
||||
"source": [
|
||||
"# Get the full train dataset (for checking accuracy while training)\n",
|
||||
"train_images = onp.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)\n",
|
||||
"train_labels = one_hot(onp.array(mnist_dataset.train_labels), n_targets)\n",
|
||||
"train_images = np.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)\n",
|
||||
"train_labels = one_hot(np.array(mnist_dataset.train_labels), n_targets)\n",
|
||||
"\n",
|
||||
"# Get full test dataset\n",
|
||||
"mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)\n",
|
||||
"test_images = np.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtype=np.float32)\n",
|
||||
"test_labels = one_hot(onp.array(mnist_dataset_test.test_labels), n_targets)"
|
||||
"test_images = jnp.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtype=jnp.float32)\n",
|
||||
"test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -33,9 +33,9 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as onp\n",
|
||||
"import numpy as np\n",
|
||||
"import jax\n",
|
||||
"import jax.numpy as np\n",
|
||||
"import jax.numpy as jnp\n",
|
||||
"from jax import jit, grad, vmap\n",
|
||||
"from jax import random"
|
||||
]
|
||||
@ -84,7 +84,7 @@
|
||||
"source": [
|
||||
"x = random.normal(random.PRNGKey(0), (5000, 5000))\n",
|
||||
"def f(w, b, x):\n",
|
||||
" return np.tanh(np.dot(x, w) + b)\n",
|
||||
" return jnp.tanh(jnp.dot(x, w) + b)\n",
|
||||
"fast_f = jit(f)"
|
||||
]
|
||||
},
|
||||
@ -194,10 +194,10 @@
|
||||
"print()\n",
|
||||
"\n",
|
||||
"def bar(w, b, x):\n",
|
||||
" return np.dot(w, x) + b + np.ones(5), x\n",
|
||||
" return jnp.dot(w, x) + b + jnp.ones(5), x\n",
|
||||
"print(\"bar\")\n",
|
||||
"print(\"=====\")\n",
|
||||
"examine_jaxpr(jax.make_jaxpr(bar)(np.ones((5, 10)), np.ones(5), np.ones(10)))"
|
||||
"examine_jaxpr(jax.make_jaxpr(bar)(jnp.ones((5, 10)), jnp.ones(5), jnp.ones(10)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -257,9 +257,9 @@
|
||||
"Goal:\n",
|
||||
"```python\n",
|
||||
"def f(x):\n",
|
||||
" return np.exp(np.tanh(x))\n",
|
||||
" return jnp.exp(jnp.tanh(x))\n",
|
||||
"f_inv = inverse(f)\n",
|
||||
"assert np.allclose(f_inv(f(1.0)), 1.0)\n",
|
||||
"assert jnp.allclose(f_inv(f(1.0)), 1.0)\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"The way we'll implement this is by (1) tracing `f` into a Jaxpr, then (2) interpreting the Jaxpr *backwards*. While interpreting the Jaxpr backwards, for each equation we'll look up the primitive's inverse in a table and apply it.\n",
|
||||
@ -280,7 +280,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Importing Jax functions useful for tracing/interpreting.\n",
|
||||
"import numpy as onp\n",
|
||||
"import numpy as np\n",
|
||||
"from functools import wraps\n",
|
||||
"\n",
|
||||
"from jax import api_util\n",
|
||||
@ -308,7 +308,7 @@
|
||||
" def pv_like(x):\n",
|
||||
" # ShapedArrays are abstract values that carry around\n",
|
||||
" # shape and dtype information\n",
|
||||
" aval = ShapedArray(onp.shape(x), onp.result_type(x))\n",
|
||||
" aval = ShapedArray(np.shape(x), np.result_type(x))\n",
|
||||
" return pe.PartialVal.unknown(aval)\n",
|
||||
"\n",
|
||||
" @wraps(fun)\n",
|
||||
@ -366,8 +366,8 @@
|
||||
],
|
||||
"source": [
|
||||
"def f(x):\n",
|
||||
" return np.exp(np.tanh(x))\n",
|
||||
"jaxpr, consts, _ = make_jaxpr2(f)(np.ones(5))\n",
|
||||
" return jnp.exp(jnp.tanh(x))\n",
|
||||
"jaxpr, consts, _ = make_jaxpr2(f)(jnp.ones(5))\n",
|
||||
"print(jaxpr)\n",
|
||||
"print(consts)"
|
||||
]
|
||||
@ -457,8 +457,8 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"jaxpr, consts, _ = make_jaxpr2(f)(np.ones(5))\n",
|
||||
"eval_jaxpr(jaxpr, consts, np.ones(5))"
|
||||
"jaxpr, consts, _ = make_jaxpr2(f)(jnp.ones(5))\n",
|
||||
"eval_jaxpr(jaxpr, consts, jnp.ones(5))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -521,8 +521,8 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"inverse_registry[lax.exp_p] = np.log\n",
|
||||
"inverse_registry[lax.tanh_p] = np.arctanh"
|
||||
"inverse_registry[lax.exp_p] = jnp.log\n",
|
||||
"inverse_registry[lax.tanh_p] = jnp.arctanh"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -639,9 +639,9 @@
|
||||
],
|
||||
"source": [
|
||||
"def f(x):\n",
|
||||
" return np.exp(np.tanh(x))\n",
|
||||
" return jnp.exp(jnp.tanh(x))\n",
|
||||
"f_inv = inverse(f)\n",
|
||||
"assert np.allclose(f_inv(f(1.0)), 1.0)"
|
||||
"assert jnp.allclose(f_inv(f(1.0)), 1.0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -728,7 +728,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"jit(vmap(grad(inverse(f))))((np.arange(5) + 1.) / 5.)"
|
||||
"jit(vmap(grad(inverse(f))))((jnp.arange(5) + 1.) / 5.)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -97,8 +97,7 @@
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"# We import as onp to emphasize that we're using vanilla numpy, not jax numpy.\n",
|
||||
"import numpy as onp\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"# We only need to import JAX's xla_client, not all of JAX.\n",
|
||||
"from jaxlib import xla_client\n",
|
||||
@ -138,12 +137,12 @@
|
||||
"def canonicalize_dtype(dtype):\n",
|
||||
" \"\"\"We restrict ourselves to 32bit types for this demo.\"\"\"\n",
|
||||
" _dtype_to_32bit_dtype = {\n",
|
||||
" str(onp.dtype('int64')): onp.dtype('int32'),\n",
|
||||
" str(onp.dtype('uint64')): onp.dtype('uint32'),\n",
|
||||
" str(onp.dtype('float64')): onp.dtype('float32'),\n",
|
||||
" str(onp.dtype('complex128')): onp.dtype('complex64'),\n",
|
||||
" str(np.dtype('int64')): np.dtype('int32'),\n",
|
||||
" str(np.dtype('uint64')): np.dtype('uint32'),\n",
|
||||
" str(np.dtype('float64')): np.dtype('float32'),\n",
|
||||
" str(np.dtype('complex128')): np.dtype('complex64'),\n",
|
||||
" }\n",
|
||||
" dtype = onp.dtype(dtype)\n",
|
||||
" dtype = np.dtype(dtype)\n",
|
||||
" return _dtype_to_32bit_dtype.get(str(dtype), dtype)\n",
|
||||
"\n",
|
||||
"def shape_of(value):\n",
|
||||
@ -151,8 +150,8 @@
|
||||
" if hasattr(value, 'shape') and hasattr(value, 'dtype'):\n",
|
||||
" return xla_client.Shape.array_shape(canonicalize_dtype(value.dtype), \n",
|
||||
" value.shape)\n",
|
||||
" elif onp.isscalar(value):\n",
|
||||
" return shape_of(onp.asarray(value))\n",
|
||||
" elif np.isscalar(value):\n",
|
||||
" return shape_of(np.asarray(value))\n",
|
||||
" elif isinstance(value, (tuple, list)):\n",
|
||||
" return xla_client.Shape.tuple_shape(tuple(shape_of(elt) for elt in value))\n",
|
||||
" else:\n",
|
||||
@ -163,8 +162,8 @@
|
||||
" if isinstance(dtype, str):\n",
|
||||
" return xla_client.DTYPE_TO_XLA_ELEMENT_TYPE[dtype]\n",
|
||||
" elif isinstance(dtype, type):\n",
|
||||
" return xla_client.DTYPE_TO_XLA_ELEMENT_TYPE[onp.dtype(dtype).name]\n",
|
||||
" elif isinstance(dtype, onp.dtype):\n",
|
||||
" return xla_client.DTYPE_TO_XLA_ELEMENT_TYPE[np.dtype(dtype).name]\n",
|
||||
" elif isinstance(dtype, np.dtype):\n",
|
||||
" return xla_client.DTYPE_TO_XLA_ELEMENT_TYPE[dtype.name]\n",
|
||||
" else:\n",
|
||||
" raise TypeError('Unexpected type: {}'.format(type(dtype)))"
|
||||
@ -197,7 +196,7 @@
|
||||
"c = xla_client.ComputationBuilder(\"simple_scalar\")\n",
|
||||
"\n",
|
||||
"# define a parameter shape and parameter\n",
|
||||
"param_shape = xla_client.Shape.array_shape(onp.dtype(onp.float32), ())\n",
|
||||
"param_shape = xla_client.Shape.array_shape(np.dtype(np.float32), ())\n",
|
||||
"x = c.ParameterWithShape(param_shape)\n",
|
||||
"\n",
|
||||
"# define computation graph\n",
|
||||
@ -212,7 +211,7 @@
|
||||
"compiled_computation = computation.Compile([param_shape,])\n",
|
||||
"\n",
|
||||
"# define a host variable with above parameter shape\n",
|
||||
"host_input = onp.array(3.0, dtype=onp.float32)\n",
|
||||
"host_input = np.array(3.0, dtype=np.float32)\n",
|
||||
"\n",
|
||||
"# place host variable on device and execute\n",
|
||||
"device_input = xla_client.LocalBuffer.from_pyval(host_input)\n",
|
||||
@ -251,15 +250,15 @@
|
||||
"# same as above with vector type:\n",
|
||||
"\n",
|
||||
"c = xla_client.ComputationBuilder(\"simple_vector\")\n",
|
||||
"param_shape = xla_client.Shape.array_shape(onp.dtype(onp.float32), (3,))\n",
|
||||
"param_shape = xla_client.Shape.array_shape(np.dtype(np.float32), (3,))\n",
|
||||
"x = c.ParameterWithShape(param_shape)\n",
|
||||
"\n",
|
||||
"# can also use this function to define a shape from an example:\n",
|
||||
"#x = c.ParameterFromNumpy(onp.array([0.0, 0.0, 0.0], dtype=onp.float32))\n",
|
||||
"#x = c.ParameterFromNumpy(np.array([0.0, 0.0, 0.0], dtype=np.float32))\n",
|
||||
"\n",
|
||||
"# which is the same as using our convenience function above:\n",
|
||||
"#x = c.ParameterWithShape(shape_of(onp.array([0.0, 0.0, 0.0], \n",
|
||||
"# dtype=onp.float32)))\n",
|
||||
"#x = c.ParameterWithShape(shape_of(np.array([0.0, 0.0, 0.0], \n",
|
||||
"# dtype=np.float32)))\n",
|
||||
"\n",
|
||||
"# chain steps by reference:\n",
|
||||
"y = c.Sin(x)\n",
|
||||
@ -267,7 +266,7 @@
|
||||
"computation = c.Build()\n",
|
||||
"compiled_computation = computation.Compile([param_shape,])\n",
|
||||
"\n",
|
||||
"host_input = onp.array([3.0, 4.0, 5.0], dtype=onp.float32)\n",
|
||||
"host_input = np.array([3.0, 4.0, 5.0], dtype=np.float32)\n",
|
||||
"\n",
|
||||
"device_input = xla_client.LocalBuffer.from_pyval(host_input)\n",
|
||||
"device_out = compiled_computation.Execute([device_input ,])\n",
|
||||
@ -322,14 +321,14 @@
|
||||
"# body computation:\n",
|
||||
"bcb = xla_client.ComputationBuilder(\"bodycomp\")\n",
|
||||
"x = bcb.ParameterWithShape(in_shape)\n",
|
||||
"const1 = bcb.Constant(onp.int32(1))\n",
|
||||
"const1 = bcb.Constant(np.int32(1))\n",
|
||||
"y = bcb.Sub(x, const1)\n",
|
||||
"body_computation = bcb.Build()\n",
|
||||
"\n",
|
||||
"# test computation:\n",
|
||||
"tcb = xla_client.ComputationBuilder(\"testcomp\")\n",
|
||||
"x = tcb.ParameterWithShape(in_shape)\n",
|
||||
"const0 = tcb.Constant(onp.int32(0))\n",
|
||||
"const0 = tcb.Constant(np.int32(0))\n",
|
||||
"y = tcb.Gt(x, const0)\n",
|
||||
"test_computation = tcb.Build()\n",
|
||||
"\n",
|
||||
@ -342,7 +341,7 @@
|
||||
"# Now compile and execute:\n",
|
||||
"compiled_computation = while_computation.Compile([in_shape,])\n",
|
||||
"\n",
|
||||
"host_input = onp.array(5, dtype=onp.int32)\n",
|
||||
"host_input = np.array(5, dtype=np.int32)\n",
|
||||
"\n",
|
||||
"device_input = xla_client.LocalBuffer.from_pyval(host_input)\n",
|
||||
"device_out = compiled_computation.Execute([device_input ,])\n",
|
||||
@ -402,7 +401,7 @@
|
||||
"x = bcb.GetTupleElement(intuple, 1)\n",
|
||||
"guard_cntr = bcb.GetTupleElement(intuple, 2)\n",
|
||||
"new_x = bcb.Sub(x, bcb.Div(bcb.Sub(bcb.Mul(x, x), y), bcb.Add(x, x)))\n",
|
||||
"result = bcb.Tuple(y, new_x, bcb.Sub(guard_cntr, bcb.Constant(onp.int32(1))))\n",
|
||||
"result = bcb.Tuple(y, new_x, bcb.Sub(guard_cntr, bcb.Constant(np.int32(1))))\n",
|
||||
"body_computation = bcb.Build()\n",
|
||||
"\n",
|
||||
"# test computation -- convergence and max iteration test\n",
|
||||
@ -413,8 +412,8 @@
|
||||
"guard_cntr = tcb.GetTupleElement(intuple, 2)\n",
|
||||
"criterion = tcb.Abs(tcb.Sub(tcb.Mul(x, x), y))\n",
|
||||
"# stop at convergence criteria or too many iterations\n",
|
||||
"test = tcb.And(tcb.Gt(criterion, tcb.Constant(onp.float32(converged_delta))), \n",
|
||||
" tcb.Gt(guard_cntr, tcb.Constant(onp.int32(0))))\n",
|
||||
"test = tcb.And(tcb.Gt(criterion, tcb.Constant(np.float32(converged_delta))), \n",
|
||||
" tcb.Gt(guard_cntr, tcb.Constant(np.int32(0))))\n",
|
||||
"test_computation = tcb.Build()\n",
|
||||
"\n",
|
||||
"# while computation:\n",
|
||||
@ -426,9 +425,9 @@
|
||||
"# Now compile and execute:\n",
|
||||
"compiled_computation = while_computation.Compile([in_shape,])\n",
|
||||
"\n",
|
||||
"y = onp.array(Xsqr, dtype=onp.float32)\n",
|
||||
"x = onp.array(guess, dtype=onp.float32)\n",
|
||||
"maxit = onp.array(maxit, dtype=onp.int32)\n",
|
||||
"y = np.array(Xsqr, dtype=np.float32)\n",
|
||||
"x = np.array(guess, dtype=np.float32)\n",
|
||||
"maxit = np.array(maxit, dtype=np.int32)\n",
|
||||
"\n",
|
||||
"device_input = xla_client.LocalBuffer.from_pyval((y, x, maxit))\n",
|
||||
"device_out = compiled_computation.Execute([device_input ,])\n",
|
||||
@ -483,12 +482,12 @@
|
||||
"Niter = 200\n",
|
||||
"matrix_shape = (10, 10)\n",
|
||||
"in_shape = shape_of(\n",
|
||||
" (onp.zeros(matrix_shape, dtype=onp.float32), 1)\n",
|
||||
" (np.zeros(matrix_shape, dtype=np.float32), 1)\n",
|
||||
")\n",
|
||||
"# NB: in_shape is the same as the manually constructed:\n",
|
||||
"# xla_client.Shape.tuple_shape(\n",
|
||||
"# (xla_client.Shape.array_shape(onp.dtype(onp.float32), matrix_shape), \n",
|
||||
"# xla_client.Shape.array_shape(onp.dtype(onp.int32), ()))\n",
|
||||
"# (xla_client.Shape.array_shape(np.dtype(np.float32), matrix_shape), \n",
|
||||
"# xla_client.Shape.array_shape(np.dtype(np.int32), ()))\n",
|
||||
"# )\n",
|
||||
"\n",
|
||||
"# body computation -- QR loop: X_i = Q R , X_{i+1} = R Q\n",
|
||||
@ -500,14 +499,14 @@
|
||||
"Q = bcb.GetTupleElement(QR, 0)\n",
|
||||
"R = bcb.GetTupleElement(QR, 1)\n",
|
||||
"RQ = bcb.Dot(R, Q)\n",
|
||||
"bcb.Tuple(RQ, bcb.Sub(cntr, bcb.Constant(onp.int32(1))))\n",
|
||||
"bcb.Tuple(RQ, bcb.Sub(cntr, bcb.Constant(np.int32(1))))\n",
|
||||
"body_computation = bcb.Build()\n",
|
||||
"\n",
|
||||
"# test computation -- just a for loop condition\n",
|
||||
"tcb = xla_client.ComputationBuilder(\"testcomp\")\n",
|
||||
"intuple = tcb.ParameterWithShape(in_shape)\n",
|
||||
"cntr = tcb.GetTupleElement(intuple, 1)\n",
|
||||
"test = tcb.Gt(cntr, tcb.Constant(onp.int32(0)))\n",
|
||||
"test = tcb.Gt(cntr, tcb.Constant(np.int32(0)))\n",
|
||||
"test_computation = tcb.Build()\n",
|
||||
"\n",
|
||||
"# while computation:\n",
|
||||
@ -519,9 +518,9 @@
|
||||
"# Now compile and execute:\n",
|
||||
"compiled_computation = while_computation.Compile([in_shape,])\n",
|
||||
"\n",
|
||||
"X = onp.random.random(matrix_shape).astype(onp.float32)\n",
|
||||
"X = np.random.random(matrix_shape).astype(np.float32)\n",
|
||||
"X = (X + X.T) / 2.0\n",
|
||||
"it = onp.array(Niter, dtype=onp.int32)\n",
|
||||
"it = np.array(Niter, dtype=np.int32)\n",
|
||||
"\n",
|
||||
"device_in = xla_client.LocalBuffer.from_pyval((X, it))\n",
|
||||
"device_out = compiled_computation.Execute([device_in,])\n",
|
||||
@ -532,11 +531,11 @@
|
||||
"plt.title('D')\n",
|
||||
"plt.imshow(host_out[0])\n",
|
||||
"print('sorted eigenvalues')\n",
|
||||
"print(onp.sort(eigh_vals))\n",
|
||||
"print(np.sort(eigh_vals))\n",
|
||||
"print('sorted eigenvalues from numpy')\n",
|
||||
"print(onp.sort(onp.linalg.eigh(X)[0]))\n",
|
||||
"print(np.sort(np.linalg.eigh(X)[0]))\n",
|
||||
"print('sorted error') \n",
|
||||
"print(onp.sort(eigh_vals) - onp.sort(onp.linalg.eigh(X)[0]))"
|
||||
"print(np.sort(eigh_vals) - np.sort(np.linalg.eigh(X)[0]))"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": [
|
||||
@ -605,8 +604,8 @@
|
||||
"Niter = 100\n",
|
||||
"matrix_shape = (10, 10)\n",
|
||||
"in_shape = shape_of(\n",
|
||||
" (onp.zeros(matrix_shape, dtype=onp.float32), \n",
|
||||
" onp.eye(matrix_shape[0]),\n",
|
||||
" (np.zeros(matrix_shape, dtype=np.float32), \n",
|
||||
" np.eye(matrix_shape[0]),\n",
|
||||
" 1)\n",
|
||||
")\n",
|
||||
"\n",
|
||||
@ -621,14 +620,14 @@
|
||||
"R = bcb.GetTupleElement(QR, 1)\n",
|
||||
"RQ = bcb.Dot(R, Q)\n",
|
||||
"Onew = bcb.Dot(O, Q)\n",
|
||||
"bcb.Tuple(RQ, Onew, bcb.Sub(cntr, bcb.Constant(onp.int32(1))))\n",
|
||||
"bcb.Tuple(RQ, Onew, bcb.Sub(cntr, bcb.Constant(np.int32(1))))\n",
|
||||
"body_computation = bcb.Build()\n",
|
||||
"\n",
|
||||
"# test computation -- just a for loop condition\n",
|
||||
"tcb = xla_client.ComputationBuilder(\"testcomp\")\n",
|
||||
"intuple = tcb.ParameterWithShape(in_shape)\n",
|
||||
"cntr = tcb.GetTupleElement(intuple, 2)\n",
|
||||
"test = tcb.Gt(cntr, tcb.Constant(onp.int32(0)))\n",
|
||||
"test = tcb.Gt(cntr, tcb.Constant(np.int32(0)))\n",
|
||||
"test_computation = tcb.Build()\n",
|
||||
"\n",
|
||||
"# while computation:\n",
|
||||
@ -640,10 +639,10 @@
|
||||
"# Now compile and execute:\n",
|
||||
"compiled_computation = while_computation.Compile([in_shape,])\n",
|
||||
"\n",
|
||||
"X = onp.random.random(matrix_shape).astype(onp.float32)\n",
|
||||
"X = np.random.random(matrix_shape).astype(np.float32)\n",
|
||||
"X = (X + X.T) / 2.0\n",
|
||||
"Omat = onp.eye(matrix_shape[0], dtype=onp.float32)\n",
|
||||
"it = onp.array(Niter, dtype=onp.int32)\n",
|
||||
"Omat = np.eye(matrix_shape[0], dtype=np.float32)\n",
|
||||
"it = np.array(Niter, dtype=np.int32)\n",
|
||||
"\n",
|
||||
"device_in = xla_client.LocalBuffer.from_pyval((X, Omat, it))\n",
|
||||
"device_out = compiled_computation.Execute([device_in,])\n",
|
||||
@ -659,13 +658,13 @@
|
||||
"plt.imshow(eigh_mat)\n",
|
||||
"plt.figure()\n",
|
||||
"plt.title('U^T A U')\n",
|
||||
"plt.imshow(onp.dot(onp.dot(eigh_mat.T, X), eigh_mat))\n",
|
||||
"plt.imshow(np.dot(np.dot(eigh_mat.T, X), eigh_mat))\n",
|
||||
"print('sorted eigenvalues')\n",
|
||||
"print(onp.sort(eigh_vals))\n",
|
||||
"print(np.sort(eigh_vals))\n",
|
||||
"print('sorted eigenvalues from numpy')\n",
|
||||
"print(onp.sort(onp.linalg.eigh(X)[0]))\n",
|
||||
"print(np.sort(np.linalg.eigh(X)[0]))\n",
|
||||
"print('sorted error') \n",
|
||||
"print(onp.sort(eigh_vals) - onp.sort(onp.linalg.eigh(X)[0]))"
|
||||
"print(np.sort(eigh_vals) - np.sort(np.linalg.eigh(X)[0]))"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": [
|
||||
@ -752,7 +751,7 @@
|
||||
"Niter=13\n",
|
||||
"matrix_shape = (1,1, 20, 20)\n",
|
||||
"in_shape = shape_of(\n",
|
||||
" (onp.zeros(matrix_shape, dtype=onp.int32), 1)\n",
|
||||
" (np.zeros(matrix_shape, dtype=np.int32), 1)\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Body computation -- Conway Update\n",
|
||||
@ -762,8 +761,8 @@
|
||||
"cntr = bcb.GetTupleElement(intuple, 1)\n",
|
||||
"# convs require floating-point type\n",
|
||||
"xf = bcb.ConvertElementType(x, to_xla_type('float32'))\n",
|
||||
"stamp = bcb.Constant(onp.ones((1,1,3,3), dtype=onp.float32))\n",
|
||||
"convd = bcb.Conv(xf, stamp, onp.array([1, 1]), xla_client.PaddingType.SAME)\n",
|
||||
"stamp = bcb.Constant(np.ones((1,1,3,3), dtype=np.float32))\n",
|
||||
"convd = bcb.Conv(xf, stamp, np.array([1, 1]), xla_client.PaddingType.SAME)\n",
|
||||
"# logic ops require integer types\n",
|
||||
"convd = bcb.ConvertElementType(convd, to_xla_type('int32'))\n",
|
||||
"bool_x = bcb.Eq(x, bcb.ConstantS32Scalar(1))\n",
|
||||
@ -800,13 +799,13 @@
|
||||
"compiled_computation = while_computation.Compile([in_shape,])\n",
|
||||
"\n",
|
||||
"# Set up initial state\n",
|
||||
"X = onp.zeros(matrix_shape, dtype=onp.int32)\n",
|
||||
"X[0,0, 5:8, 5:8] = onp.array([[0,1,0],[0,0,1],[1,1,1]])\n",
|
||||
"X = np.zeros(matrix_shape, dtype=np.int32)\n",
|
||||
"X[0,0, 5:8, 5:8] = np.array([[0,1,0],[0,0,1],[1,1,1]])\n",
|
||||
"\n",
|
||||
"# Evolve\n",
|
||||
"movie = onp.zeros((Niter,)+matrix_shape[-2:], dtype=onp.int32)\n",
|
||||
"movie = np.zeros((Niter,)+matrix_shape[-2:], dtype=np.int32)\n",
|
||||
"for it in range(Niter):\n",
|
||||
" itr = onp.array(it, dtype=onp.int32)\n",
|
||||
" itr = np.array(it, dtype=np.int32)\n",
|
||||
" device_in = xla_client.LocalBuffer.from_pyval((X, itr))\n",
|
||||
" device_out = compiled_computation.Execute([device_in,])\n",
|
||||
" movie[it] = device_out.to_py()[0][0,0]\n",
|
||||
|
@ -52,7 +52,7 @@
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"import jax.numpy as np\n",
|
||||
"import jax.numpy as jnp\n",
|
||||
"from jax import grad, jit, vmap\n",
|
||||
"from jax import random\n",
|
||||
"\n",
|
||||
@ -104,7 +104,7 @@
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"grad_tanh = grad(np.tanh)\n",
|
||||
"grad_tanh = grad(jnp.tanh)\n",
|
||||
"print(grad_tanh(2.0))"
|
||||
],
|
||||
"execution_count": 2,
|
||||
@ -142,8 +142,8 @@
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"print(grad(grad(np.tanh))(2.0))\n",
|
||||
"print(grad(grad(grad(np.tanh)))(2.0))"
|
||||
"print(grad(grad(jnp.tanh))(2.0))\n",
|
||||
"print(grad(grad(grad(jnp.tanh)))(2.0))"
|
||||
],
|
||||
"execution_count": 3,
|
||||
"outputs": [
|
||||
@ -176,24 +176,24 @@
|
||||
},
|
||||
"source": [
|
||||
"def sigmoid(x):\n",
|
||||
" return 0.5 * (np.tanh(x / 2) + 1)\n",
|
||||
" return 0.5 * (jnp.tanh(x / 2) + 1)\n",
|
||||
"\n",
|
||||
"# Outputs probability of a label being true.\n",
|
||||
"def predict(W, b, inputs):\n",
|
||||
" return sigmoid(np.dot(inputs, W) + b)\n",
|
||||
" return sigmoid(jnp.dot(inputs, W) + b)\n",
|
||||
"\n",
|
||||
"# Build a toy dataset.\n",
|
||||
"inputs = np.array([[0.52, 1.12, 0.77],\n",
|
||||
"inputs = jnp.array([[0.52, 1.12, 0.77],\n",
|
||||
" [0.88, -1.08, 0.15],\n",
|
||||
" [0.52, 0.06, -1.30],\n",
|
||||
" [0.74, -2.49, 1.39]])\n",
|
||||
"targets = np.array([True, True, False, True])\n",
|
||||
"targets = jnp.array([True, True, False, True])\n",
|
||||
"\n",
|
||||
"# Training loss is the negative log-likelihood of the training examples.\n",
|
||||
"def loss(W, b):\n",
|
||||
" preds = predict(W, b, inputs)\n",
|
||||
" label_probs = preds * targets + (1 - preds) * (1 - targets)\n",
|
||||
" return -np.sum(np.log(label_probs))\n",
|
||||
" return -jnp.sum(jnp.log(label_probs))\n",
|
||||
"\n",
|
||||
"# Initialize random model coefficients\n",
|
||||
"key, W_key, b_key = random.split(key, 3)\n",
|
||||
@ -304,7 +304,7 @@
|
||||
"def loss2(params_dict):\n",
|
||||
" preds = predict(params_dict['W'], params_dict['b'], inputs)\n",
|
||||
" label_probs = preds * targets + (1 - preds) * (1 - targets)\n",
|
||||
" return -np.sum(np.log(label_probs))\n",
|
||||
" return -jnp.sum(jnp.log(label_probs))\n",
|
||||
"\n",
|
||||
"print(grad(loss2)({'W': W, 'b': b}))"
|
||||
],
|
||||
@ -413,10 +413,10 @@
|
||||
"# Check W_grad with finite differences in a random direction\n",
|
||||
"key, subkey = random.split(key)\n",
|
||||
"vec = random.normal(subkey, W.shape)\n",
|
||||
"unitvec = vec / np.sqrt(np.vdot(vec, vec))\n",
|
||||
"unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))\n",
|
||||
"W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps\n",
|
||||
"print('W_dirderiv_numerical', W_grad_numerical)\n",
|
||||
"print('W_dirderiv_autodiff', np.vdot(grad(loss)(W, b), unitvec))"
|
||||
"print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec))"
|
||||
],
|
||||
"execution_count": 8,
|
||||
"outputs": [
|
||||
@ -508,7 +508,7 @@
|
||||
},
|
||||
"source": [
|
||||
"def hvp(f, x, v):\n",
|
||||
" return grad(lambda x: np.vdot(grad(f)(x), v))(x)"
|
||||
" return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
@ -945,7 +945,7 @@
|
||||
},
|
||||
"source": [
|
||||
"def hvp(f, x, v):\n",
|
||||
" return grad(lambda x: np.vdot(grad(f)(x), v))(x)"
|
||||
" return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
@ -994,7 +994,7 @@
|
||||
"id": "XUsye1SwfSFm"
|
||||
},
|
||||
"source": [
|
||||
"Even better, since we didn't have to call `np.dot` directly, this `hvp` function works with arrays of any shape and with arbitrary container types (like vectors stored as nested lists/dicts/tuples), and doesn't even have a dependence on `jax.numpy`.\n",
|
||||
"Even better, since we didn't have to call `jnp.dot` directly, this `hvp` function works with arrays of any shape and with arbitrary container types (like vectors stored as nested lists/dicts/tuples), and doesn't even have a dependence on `jax.numpy`.\n",
|
||||
"\n",
|
||||
"Here's an example of how to use it:"
|
||||
]
|
||||
@ -1012,16 +1012,16 @@
|
||||
},
|
||||
"source": [
|
||||
"def f(X):\n",
|
||||
" return np.sum(np.tanh(X)**2)\n",
|
||||
" return jnp.sum(jnp.tanh(X)**2)\n",
|
||||
"\n",
|
||||
"key, subkey1, subkey2 = random.split(key, 3)\n",
|
||||
"X = random.normal(subkey1, (30, 40))\n",
|
||||
"V = random.normal(subkey2, (30, 40))\n",
|
||||
"\n",
|
||||
"ans1 = hvp(f, (X,), (V,))\n",
|
||||
"ans2 = np.tensordot(hessian(f)(X), V, 2)\n",
|
||||
"ans2 = jnp.tensordot(hessian(f)(X), V, 2)\n",
|
||||
"\n",
|
||||
"print(np.allclose(ans1, ans2, 1e-4, 1e-4))"
|
||||
"print(jnp.allclose(ans1, ans2, 1e-4, 1e-4))"
|
||||
],
|
||||
"execution_count": 18,
|
||||
"outputs": [
|
||||
@ -1086,7 +1086,7 @@
|
||||
"def hvp_revrev(f, primals, tangents):\n",
|
||||
" x, = primals\n",
|
||||
" v, = tangents\n",
|
||||
" return grad(lambda x: np.vdot(grad(f)(x), v))(x)\n",
|
||||
" return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(\"Forward over reverse\")\n",
|
||||
@ -1097,7 +1097,7 @@
|
||||
"%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))\n",
|
||||
"\n",
|
||||
"print(\"Naive full Hessian materialization\")\n",
|
||||
"%timeit -n10 -r3 np.tensordot(hessian(f)(X), V, 2)"
|
||||
"%timeit -n10 -r3 jnp.tensordot(hessian(f)(X), V, 2)"
|
||||
],
|
||||
"execution_count": 20,
|
||||
"outputs": [
|
||||
@ -1158,7 +1158,7 @@
|
||||
"# First, use a list comprehension to loop over rows in the matrix M.\n",
|
||||
"def loop_mjp(f, x, M):\n",
|
||||
" y, vjp_fun = vjp(f, x)\n",
|
||||
" return np.vstack([vjp_fun(mi) for mi in M])\n",
|
||||
" return jnp.vstack([vjp_fun(mi) for mi in M])\n",
|
||||
"\n",
|
||||
"# Now, use vmap to build a computation that does a single fast matrix-matrix\n",
|
||||
"# multiply, rather than an outer loop over vector-matrix multiplies.\n",
|
||||
@ -1179,7 +1179,7 @@
|
||||
"vmap_vs = vmap_mjp(f, W, M=U)\n",
|
||||
"%timeit -n10 -r3 vmap_mjp(f, W, M=U)\n",
|
||||
"\n",
|
||||
"assert np.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'"
|
||||
"assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'"
|
||||
],
|
||||
"execution_count": 21,
|
||||
"outputs": [
|
||||
@ -1211,7 +1211,7 @@
|
||||
"def loop_jmp(f, x, M):\n",
|
||||
" # jvp immediately returns the primal and tangent values as a tuple,\n",
|
||||
" # so we'll compute and select the tangents in a list comprehension\n",
|
||||
" return np.vstack([jvp(f, (W,), (mi,))[1] for mi in M])\n",
|
||||
" return jnp.vstack([jvp(f, (W,), (mi,))[1] for mi in M])\n",
|
||||
"\n",
|
||||
"def vmap_jmp(f, x, M):\n",
|
||||
" _jvp = lambda s: jvp(f, (W,), (s,))[1]\n",
|
||||
@ -1227,7 +1227,7 @@
|
||||
"print('\\nVmapped Jacobian-Matrix product')\n",
|
||||
"%timeit -n10 -r3 vmap_jmp(f, W, M=S)\n",
|
||||
"\n",
|
||||
"assert np.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'"
|
||||
"assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'"
|
||||
],
|
||||
"execution_count": 22,
|
||||
"outputs": [
|
||||
@ -1281,11 +1281,11 @@
|
||||
" # Use vmap to do a matrix-Jacobian product.\n",
|
||||
" # Here, the matrix is the Euclidean basis, so we get all\n",
|
||||
" # entries in the Jacobian at once. \n",
|
||||
" J, = vmap(vjp_fun, in_axes=0)(np.eye(len(y)))\n",
|
||||
" J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))\n",
|
||||
" return J\n",
|
||||
" return jacfun\n",
|
||||
"\n",
|
||||
"assert np.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'"
|
||||
"assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
@ -1303,11 +1303,11 @@
|
||||
"def our_jacfwd(f):\n",
|
||||
" def jacfun(x):\n",
|
||||
" _jvp = lambda s: jvp(f, (x,), (s,))[1]\n",
|
||||
" Jt =vmap(_jvp, in_axes=1)(np.eye(len(x)))\n",
|
||||
" return np.transpose(Jt)\n",
|
||||
" Jt =vmap(_jvp, in_axes=1)(jnp.eye(len(x)))\n",
|
||||
" return jnp.transpose(Jt)\n",
|
||||
" return jacfun\n",
|
||||
"\n",
|
||||
"assert np.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'"
|
||||
"assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
@ -1351,7 +1351,7 @@
|
||||
" else:\n",
|
||||
" raise ValueError\n",
|
||||
" except ValueError:\n",
|
||||
" return np.pi * x\n",
|
||||
" return jnp.pi * x\n",
|
||||
"\n",
|
||||
"y, f_vjp = vjp(f, 4.)\n",
|
||||
"print(jit(f_vjp)(1.))"
|
||||
@ -1398,7 +1398,7 @@
|
||||
},
|
||||
"source": [
|
||||
"def f(z):\n",
|
||||
" x, y = np.real(z), np.imag(z)\n",
|
||||
" x, y = jnp.real(z), jnp.imag(z)\n",
|
||||
" return u(x, y) + v(x, y) * 1j\n",
|
||||
"\n",
|
||||
"def g(x, y):\n",
|
||||
@ -1465,7 +1465,7 @@
|
||||
" a, b, c, d = random.uniform(subkey, (4,))\n",
|
||||
"\n",
|
||||
" def fun(z):\n",
|
||||
" x, y = np.real(z), np.imag(z)\n",
|
||||
" x, y = jnp.real(z), jnp.imag(z)\n",
|
||||
" return u(x, y) + v(x, y) * 1j\n",
|
||||
"\n",
|
||||
" def u(x, y):\n",
|
||||
@ -1490,7 +1490,7 @@
|
||||
" grad(u, 1)(x, y) * d +\n",
|
||||
" grad(v, 0)(x, y) * c * 1j+\n",
|
||||
" grad(v, 1)(x, y) * d * 1j)\n",
|
||||
" print(np.allclose(ans, expected))"
|
||||
" print(jnp.allclose(ans, expected))"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
@ -1567,7 +1567,7 @@
|
||||
" a, b, c, d = random.uniform(subkey, (4,))\n",
|
||||
"\n",
|
||||
" def fun(z):\n",
|
||||
" x, y = np.real(z), np.imag(z)\n",
|
||||
" x, y = jnp.real(z), jnp.imag(z)\n",
|
||||
" return u(x, y) + v(x, y) * 1j\n",
|
||||
"\n",
|
||||
" def u(x, y):\n",
|
||||
@ -1584,7 +1584,7 @@
|
||||
" # cotangent vector\n",
|
||||
" key, subkey = random.split(key)\n",
|
||||
" c, d = random.uniform(subkey, (2,))\n",
|
||||
" z_bar = np.array(c + d * 1j) # for dtype control\n",
|
||||
" z_bar = jnp.array(c + d * 1j) # for dtype control\n",
|
||||
"\n",
|
||||
" # check vjp\n",
|
||||
" _, fun_vjp = vjp(fun, z)\n",
|
||||
@ -1593,7 +1593,7 @@
|
||||
" grad(v, 0)(x, y) * (-d) +\n",
|
||||
" grad(u, 1)(x, y) * c * (-1j) +\n",
|
||||
" grad(v, 1)(x, y) * (-d) * (-1j))\n",
|
||||
" assert np.allclose(ans, expected, atol=1e-5, rtol=1e-5)"
|
||||
" assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
@ -1638,7 +1638,7 @@
|
||||
},
|
||||
"source": [
|
||||
"def f(z):\n",
|
||||
" x, y = np.real(z), np.imag(z)\n",
|
||||
" x, y = jnp.real(z), jnp.imag(z)\n",
|
||||
" return x**2 + y**2\n",
|
||||
"\n",
|
||||
"z = 3. + 4j\n",
|
||||
@ -1685,7 +1685,7 @@
|
||||
},
|
||||
"source": [
|
||||
"def f(z):\n",
|
||||
" return np.sin(z)\n",
|
||||
" return jnp.sin(z)\n",
|
||||
"\n",
|
||||
"z = 3. + 4j\n",
|
||||
"grad(f, holomorphic=True)(z)"
|
||||
@ -1729,7 +1729,7 @@
|
||||
},
|
||||
"source": [
|
||||
"def f(z):\n",
|
||||
" return np.conjugate(z)\n",
|
||||
" return jnp.conjugate(z)\n",
|
||||
"\n",
|
||||
"z = 3. + 4j\n",
|
||||
"grad(f, holomorphic=True)(z) # f is not actually holomorphic!"
|
||||
@ -1788,13 +1788,13 @@
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"A = np.array([[5., 2.+3j, 5j],\n",
|
||||
"A = jnp.array([[5., 2.+3j, 5j],\n",
|
||||
" [2.-3j, 7., 1.+7j],\n",
|
||||
" [-5j, 1.-7j, 12.]])\n",
|
||||
"\n",
|
||||
"def f(X):\n",
|
||||
" L = np.linalg.cholesky(X)\n",
|
||||
" return np.sum((L - np.sin(L))**2)\n",
|
||||
" L = jnp.linalg.cholesky(X)\n",
|
||||
" return jnp.sum((L - jnp.sin(L))**2)\n",
|
||||
"\n",
|
||||
"grad(f, holomorphic=True)(A)"
|
||||
],
|
||||
|
@ -64,7 +64,7 @@
|
||||
},
|
||||
"source": [
|
||||
"### import jax.numpy (almost-drop-in for numpy) and gradient operators.\n",
|
||||
"import jax.numpy as np\n",
|
||||
"import jax.numpy as jnp\n",
|
||||
"from jax import grad"
|
||||
],
|
||||
"execution_count": 0,
|
||||
@ -94,8 +94,8 @@
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"f = lambda x : np.exp(x)\n",
|
||||
"g = lambda x : np.square(x)\n",
|
||||
"f = lambda x : jnp.exp(x)\n",
|
||||
"g = lambda x : jnp.square(x)\n",
|
||||
"print(grad(f)(1.)) # = e^{1}\n",
|
||||
"print(grad(grad(f))(1.))\n",
|
||||
"print(grad(grad(grad(f)))(1.))\n",
|
||||
@ -184,7 +184,7 @@
|
||||
"def loss(params, inputs, targets):\n",
|
||||
" # Computes average loss for the batch\n",
|
||||
" predictions = net_apply(params, inputs)\n",
|
||||
" return np.mean((targets - predictions)**2)"
|
||||
" return jnp.mean((targets - predictions)**2)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
@ -202,8 +202,8 @@
|
||||
},
|
||||
"source": [
|
||||
"# batch the inference across K=100\n",
|
||||
"xrange_inputs = np.linspace(-5,5,100).reshape((100, 1)) # (k, 1)\n",
|
||||
"targets = np.sin(xrange_inputs)\n",
|
||||
"xrange_inputs = jnp.linspace(-5,5,100).reshape((100, 1)) # (k, 1)\n",
|
||||
"targets = jnp.sin(xrange_inputs)\n",
|
||||
"predictions = vmap(partial(net_apply, net_params))(xrange_inputs)\n",
|
||||
"losses = vmap(partial(loss, net_params))(xrange_inputs, targets) # per-input loss\n",
|
||||
"plt.plot(xrange_inputs, predictions, label='prediction')\n",
|
||||
@ -247,7 +247,7 @@
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"import numpy as onp\n",
|
||||
"import numpy as np\n",
|
||||
"from jax.experimental import optimizers\n",
|
||||
"from jax.tree_util import tree_multimap # Element-wise manipulation of collections of numpy arrays "
|
||||
],
|
||||
@ -292,7 +292,7 @@
|
||||
},
|
||||
"source": [
|
||||
"# batch the inference across K=100\n",
|
||||
"targets = np.sin(xrange_inputs)\n",
|
||||
"targets = jnp.sin(xrange_inputs)\n",
|
||||
"predictions = vmap(partial(net_apply, net_params))(xrange_inputs)\n",
|
||||
"losses = vmap(partial(loss, net_params))(xrange_inputs, targets) # per-input loss\n",
|
||||
"plt.plot(xrange_inputs, predictions, label='prediction')\n",
|
||||
@ -360,7 +360,7 @@
|
||||
"source": [
|
||||
"# gradients of gradients test for MAML\n",
|
||||
"# check numerics\n",
|
||||
"g = lambda x, y : np.square(x) + y\n",
|
||||
"g = lambda x, y : jnp.square(x) + y\n",
|
||||
"x0 = 2.\n",
|
||||
"y0 = 1.\n",
|
||||
"print('grad(g)(x0) = {}'.format(grad(g)(x0, y0))) # 2x = 4\n",
|
||||
@ -432,8 +432,8 @@
|
||||
"source": [
|
||||
"x1 = xrange_inputs\n",
|
||||
"y1 = targets\n",
|
||||
"x2 = np.array([0.])\n",
|
||||
"y2 = np.array([0.])\n",
|
||||
"x2 = jnp.array([0.])\n",
|
||||
"y2 = jnp.array([0.])\n",
|
||||
"maml_loss(net_params, x1, y1, x2, y2)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
@ -491,14 +491,14 @@
|
||||
"# Adam optimization\n",
|
||||
"for i in range(20000):\n",
|
||||
" # define the task\n",
|
||||
" A = onp.random.uniform(low=0.1, high=.5)\n",
|
||||
" phase = onp.random.uniform(low=0., high=np.pi)\n",
|
||||
" A = np.random.uniform(low=0.1, high=.5)\n",
|
||||
" phase = np.random.uniform(low=0., high=jnp.pi)\n",
|
||||
" # meta-training inner split (K examples)\n",
|
||||
" x1 = onp.random.uniform(low=-5., high=5., size=(K,1))\n",
|
||||
" y1 = A * onp.sin(x1 + phase)\n",
|
||||
" x1 = np.random.uniform(low=-5., high=5., size=(K,1))\n",
|
||||
" y1 = A * np.sin(x1 + phase)\n",
|
||||
" # meta-training outer split (1 example). Like cross-validating with respect to one example.\n",
|
||||
" x2 = onp.random.uniform(low=-5., high=5.)\n",
|
||||
" y2 = A * onp.sin(x2 + phase)\n",
|
||||
" x2 = np.random.uniform(low=-5., high=5.)\n",
|
||||
" y2 = A * np.sin(x2 + phase)\n",
|
||||
" opt_state, l = step(i, opt_state, x1, y1, x2, y2)\n",
|
||||
" np_maml_loss.append(l)\n",
|
||||
" if i % 1000 == 0:\n",
|
||||
@ -548,13 +548,13 @@
|
||||
},
|
||||
"source": [
|
||||
"# batch the inference across K=100\n",
|
||||
"targets = np.sin(xrange_inputs)\n",
|
||||
"targets = jnp.sin(xrange_inputs)\n",
|
||||
"predictions = vmap(partial(net_apply, net_params))(xrange_inputs)\n",
|
||||
"plt.plot(xrange_inputs, predictions, label='pre-update predictions')\n",
|
||||
"plt.plot(xrange_inputs, targets, label='target')\n",
|
||||
"\n",
|
||||
"x1 = onp.random.uniform(low=-5., high=5., size=(K,1))\n",
|
||||
"y1 = 1. * onp.sin(x1 + 0.)\n",
|
||||
"x1 = np.random.uniform(low=-5., high=5., size=(K,1))\n",
|
||||
"y1 = 1. * np.sin(x1 + 0.)\n",
|
||||
"\n",
|
||||
"for i in range(1,5):\n",
|
||||
" net_params = inner_update(net_params, x1, y1)\n",
|
||||
@ -619,16 +619,16 @@
|
||||
" As = []\n",
|
||||
" phases = []\n",
|
||||
" for _ in range(outer_batch_size): \n",
|
||||
" As.append(onp.random.uniform(low=0.1, high=.5))\n",
|
||||
" phases.append(onp.random.uniform(low=0., high=np.pi))\n",
|
||||
" As.append(np.random.uniform(low=0.1, high=.5))\n",
|
||||
" phases.append(np.random.uniform(low=0., high=jnp.pi))\n",
|
||||
" def get_batch():\n",
|
||||
" xs, ys = [], []\n",
|
||||
" for A, phase in zip(As, phases):\n",
|
||||
" x = onp.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))\n",
|
||||
" y = A * onp.sin(x + phase)\n",
|
||||
" x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))\n",
|
||||
" y = A * np.sin(x + phase)\n",
|
||||
" xs.append(x)\n",
|
||||
" ys.append(y)\n",
|
||||
" return np.stack(xs), np.stack(ys)\n",
|
||||
" return jnp.stack(xs), jnp.stack(ys)\n",
|
||||
" x1, y1 = get_batch()\n",
|
||||
" x2, y2 = get_batch()\n",
|
||||
" return x1, y1, x2, y2"
|
||||
@ -734,7 +734,7 @@
|
||||
"# returns scalar for all tasks.\n",
|
||||
"def batch_maml_loss(p, x1_b, y1_b, x2_b, y2_b):\n",
|
||||
" task_losses = vmap(partial(maml_loss, p))(x1_b, y1_b, x2_b, y2_b)\n",
|
||||
" return np.mean(task_losses)\n",
|
||||
" return jnp.mean(task_losses)\n",
|
||||
"\n",
|
||||
"@jit\n",
|
||||
"def step(i, opt_state, x1, y1, x2, y2):\n",
|
||||
@ -796,13 +796,13 @@
|
||||
},
|
||||
"source": [
|
||||
"# batch the inference across K=100\n",
|
||||
"targets = np.sin(xrange_inputs)\n",
|
||||
"targets = jnp.sin(xrange_inputs)\n",
|
||||
"predictions = vmap(partial(net_apply, net_params))(xrange_inputs)\n",
|
||||
"plt.plot(xrange_inputs, predictions, label='pre-update predictions')\n",
|
||||
"plt.plot(xrange_inputs, targets, label='target')\n",
|
||||
"\n",
|
||||
"x1 = onp.random.uniform(low=-5., high=5., size=(10,1))\n",
|
||||
"y1 = 1. * onp.sin(x1 + 0.)\n",
|
||||
"x1 = np.random.uniform(low=-5., high=5., size=(10,1))\n",
|
||||
"y1 = 1. * np.sin(x1 + 0.)\n",
|
||||
"\n",
|
||||
"for i in range(1,3):\n",
|
||||
" net_params = inner_update(net_params, x1, y1)\n",
|
||||
@ -851,8 +851,8 @@
|
||||
},
|
||||
"source": [
|
||||
"# Comparison of maml_loss for task batch size = 1 vs. task batch size = 8\n",
|
||||
"plt.plot(onp.convolve(np_maml_loss, [.05]*20), label='task_batch=1')\n",
|
||||
"plt.plot(onp.convolve(np_batched_maml_loss, [.05]*20), label='task_batch=4')\n",
|
||||
"plt.plot(np.convolve(np_maml_loss, [.05]*20), label='task_batch=1')\n",
|
||||
"plt.plot(np.convolve(np_batched_maml_loss, [.05]*20), label='task_batch=4')\n",
|
||||
"plt.ylim(0., 1e-1)\n",
|
||||
"plt.legend()"
|
||||
],
|
||||
|
@ -60,7 +60,7 @@
},
"outputs": [],
"source": [
"import jax.numpy as np\n",
"import jax.numpy as jnp\n",
"from jax import grad, jit, vmap\n",
"from jax import random"
]
@ -132,17 +132,17 @@
"from jax.scipy.special import logsumexp\n",
"\n",
"def relu(x):\n",
"  return np.maximum(0, x)\n",
"  return jnp.maximum(0, x)\n",
"\n",
"def predict(params, image):\n",
"  # per-example predictions\n",
"  activations = image\n",
"  for w, b in params[:-1]:\n",
"    outputs = np.dot(w, activations) + b\n",
"    outputs = jnp.dot(w, activations) + b\n",
"    activations = relu(outputs)\n",
"  \n",
"  final_w, final_b = params[-1]\n",
"  logits = np.dot(final_w, activations) + final_b\n",
"  logits = jnp.dot(final_w, activations) + final_b\n",
"  return logits - logsumexp(logits)"
]
},
@ -267,18 +267,18 @@
},
"outputs": [],
"source": [
"def one_hot(x, k, dtype=np.float32):\n",
"def one_hot(x, k, dtype=jnp.float32):\n",
"  \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n",
"  return np.array(x[:, None] == np.arange(k), dtype)\n",
"  return jnp.array(x[:, None] == jnp.arange(k), dtype)\n",
"  \n",
"def accuracy(params, images, targets):\n",
"  target_class = np.argmax(targets, axis=1)\n",
"  predicted_class = np.argmax(batched_predict(params, images), axis=1)\n",
"  return np.mean(predicted_class == target_class)\n",
"  target_class = jnp.argmax(targets, axis=1)\n",
"  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)\n",
"  return jnp.mean(predicted_class == target_class)\n",
"\n",
"def loss(params, images, targets):\n",
"  preds = batched_predict(params, images)\n",
"  return -np.mean(preds * targets)\n",
"  return -jnp.mean(preds * targets)\n",
"\n",
"@jit\n",
"def update(params, x, y):\n",
@ -325,12 +325,12 @@
"\n",
"# Full train set\n",
"train_images, train_labels = train_data['image'], train_data['label']\n",
"train_images = np.reshape(train_images, (len(train_images), num_pixels))\n",
"train_images = jnp.reshape(train_images, (len(train_images), num_pixels))\n",
"train_labels = one_hot(train_labels, num_labels)\n",
"\n",
"# Full test set\n",
"test_images, test_labels = test_data['image'], test_data['label']\n",
"test_images = np.reshape(test_images, (len(test_images), num_pixels))\n",
"test_images = jnp.reshape(test_images, (len(test_images), num_pixels))\n",
"test_labels = one_hot(test_labels, num_labels)"
]
},
@ -429,7 +429,7 @@
"for epoch in range(num_epochs):\n",
"  start_time = time.time()\n",
"  for x, y in get_train_batches():\n",
"    x = np.reshape(x, (len(x), num_pixels))\n",
"    x = jnp.reshape(x, (len(x), num_pixels))\n",
"    y = one_hot(y, num_labels)\n",
"    params = update(params, x, y)\n",
"  epoch_time = time.time() - start_time\n",
@ -37,7 +37,7 @@
"colab": {}
},
"source": [
"import jax.numpy as np\n",
"import jax.numpy as jnp\n",
"from jax import grad, jit, vmap\n",
"from jax import random"
],
@ -98,8 +98,8 @@
},
"source": [
"size = 3000\n",
"x = random.normal(key, (size, size), dtype=np.float32)\n",
"%timeit np.dot(x, x.T).block_until_ready()  # runs on the GPU"
"x = random.normal(key, (size, size), dtype=jnp.float32)\n",
"%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU"
],
"execution_count": 0,
"outputs": []
@ -124,9 +124,9 @@
"colab": {}
},
"source": [
"import numpy as onp  # original CPU-backed NumPy\n",
"x = onp.random.normal(size=(size, size)).astype(onp.float32)\n",
"%timeit np.dot(x, x.T).block_until_ready()"
"import numpy as np\n",
"x = np.random.normal(size=(size, size)).astype(np.float32)\n",
"%timeit jnp.dot(x, x.T).block_until_ready()"
],
"execution_count": 0,
"outputs": []
@ -151,9 +151,9 @@
"source": [
"from jax import device_put\n",
"\n",
"x = onp.random.normal(size=(size, size)).astype(onp.float32)\n",
"x = np.random.normal(size=(size, size)).astype(np.float32)\n",
"x = device_put(x)\n",
"%timeit np.dot(x, x.T).block_until_ready()"
"%timeit jnp.dot(x, x.T).block_until_ready()"
],
"execution_count": 0,
"outputs": []
@ -186,8 +186,8 @@
"colab": {}
},
"source": [
"x = onp.random.normal(size=(size, size)).astype(onp.float32)\n",
"%timeit onp.dot(x, x.T)"
"x = np.random.normal(size=(size, size)).astype(np.float32)\n",
"%timeit np.dot(x, x.T)"
],
"execution_count": 0,
"outputs": []
@ -237,7 +237,7 @@
},
"source": [
"def selu(x, alpha=1.67, lmbda=1.05):\n",
"  return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)\n",
"  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)\n",
"\n",
"x = random.normal(key, (1000000,))\n",
"%timeit selu(x).block_until_ready()"
@ -290,9 +290,9 @@
},
"source": [
"def sum_logistic(x):\n",
"  return np.sum(1.0 / (1.0 + np.exp(-x)))\n",
"  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))\n",
"\n",
"x_small = np.arange(3.)\n",
"x_small = jnp.arange(3.)\n",
"derivative_fn = grad(sum_logistic)\n",
"print(derivative_fn(x_small))"
],
@ -319,8 +319,8 @@
"source": [
"def first_finite_differences(f, x):\n",
"  eps = 1e-3\n",
"  return np.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)\n",
"                   for v in np.eye(len(x))])\n",
"  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)\n",
"                    for v in jnp.eye(len(x))])\n",
"\n",
"\n",
"print(first_finite_differences(sum_logistic, x_small))"
@ -418,7 +418,7 @@
"batched_x = random.normal(key, (10, 100))\n",
"\n",
"def apply_matrix(v):\n",
"  return np.dot(mat, v)"
"  return jnp.dot(mat, v)"
],
"execution_count": 0,
"outputs": []
@ -442,7 +442,7 @@
},
"source": [
"def naively_batched_apply_matrix(v_batched):\n",
"  return np.stack([apply_matrix(v) for v in v_batched])\n",
"  return jnp.stack([apply_matrix(v) for v in v_batched])\n",
"\n",
"print('Naively batched')\n",
"%timeit naively_batched_apply_matrix(batched_x).block_until_ready()"
@ -457,7 +457,7 @@
"colab_type": "text"
},
"source": [
"We know how to batch this operation manually. In this case, `np.dot` handles extra batch dimensions transparently."
"We know how to batch this operation manually. In this case, `jnp.dot` handles extra batch dimensions transparently."
]
},
{
@ -470,7 +470,7 @@
"source": [
"@jit\n",
"def batched_apply_matrix(v_batched):\n",
"  return np.dot(v_batched, mat.T)\n",
"  return jnp.dot(v_batched, mat.T)\n",
"\n",
"print('Manually batched')\n",
"%timeit batched_apply_matrix(batched_x).block_until_ready()"
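For readers skimming the hunks above: the renamed markdown cell says that batching can also be done without hand-writing a batched `jnp.dot`. A minimal sketch of the `vmap` step the quickstart builds toward, with illustrative shapes rather than values taken from the diff:

```python
import jax.numpy as jnp
from jax import random, jit, vmap

key = random.PRNGKey(0)
mat = random.normal(key, (150, 100))       # illustrative shapes
batched_x = random.normal(key, (10, 100))

def apply_matrix(v):
    return jnp.dot(mat, v)

# vmap maps apply_matrix over the leading axis of batched_x,
# so the manually batched version above is no longer needed.
auto_batched_apply_matrix = jit(vmap(apply_matrix))
print(auto_batched_apply_matrix(batched_x).shape)  # (10, 150)
```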
@ -63,14 +63,14 @@
"source": [
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"import numpy as onp\n",
"import numpy as np\n",
"\n",
"from sklearn.datasets import make_swiss_roll\n",
"\n",
"def sample_batch(size, noise=1.0):\n",
"    x, _= make_swiss_roll(size, noise=noise)\n",
"    x = x[:, [0, 2]] / 10.0\n",
"    return onp.array(x)\n",
"    return np.array(x)\n",
"\n",
"plt.scatter(*sample_batch(10**4).T, alpha=0.1)"
],
@ -131,7 +131,7 @@
},
"source": [
"import jax\n",
"import jax.numpy as np\n",
"import jax.numpy as jnp\n",
"from jax.experimental import optimizers\n",
"from jax.experimental import stax\n",
"from functools import partial\n",
@ -167,10 +167,10 @@
"    # we use jax.vmap to vectorize jacobian function along batch dimension\n",
"    batch_jacobian = jax.vmap(partial(jacobian, net_params))(inputs)  # [batch, dim, dim]\n",
"    \n",
"    trace_jacobian = np.trace(batch_jacobian, axis1=1, axis2=2)\n",
"    output_norm_sq = np.square(net_apply(net_params, inputs)).sum(axis=1)\n",
"    trace_jacobian = jnp.trace(batch_jacobian, axis1=1, axis2=2)\n",
"    output_norm_sq = jnp.square(net_apply(net_params, inputs)).sum(axis=1)\n",
"    \n",
"    return np.mean(trace_jacobian + 1/2 * output_norm_sq)\n",
"    return jnp.mean(trace_jacobian + 1/2 * output_norm_sq)\n",
"\n",
"\n",
"@jax.jit\n",
@ -243,16 +243,16 @@
"        clear_output(True)\n",
"        plt.figure(figsize=[16, 8])\n",
"        plt.subplot(1, 2, 1)\n",
"        plt.title(\"mean loss = %.3f\" % np.mean(np.array(loss_history[-32:])))\n",
"        plt.scatter(np.arange(len(loss_history)), loss_history)\n",
"        plt.title(\"mean loss = %.3f\" % jnp.mean(jnp.array(loss_history[-32:])))\n",
"        plt.scatter(jnp.arange(len(loss_history)), loss_history)\n",
"        plt.grid()\n",
"        \n",
"        plt.subplot(1, 2, 2)\n",
"        net_params = get_params(opt_state)\n",
"        xx = np.stack(np.meshgrid(np.linspace(-1.5, 2.0, 50), np.linspace(-1.5, 2.0, 50)), axis=-1).reshape(-1, 2)\n",
"        xx = jnp.stack(jnp.meshgrid(jnp.linspace(-1.5, 2.0, 50), jnp.linspace(-1.5, 2.0, 50)), axis=-1).reshape(-1, 2)\n",
"        scores = net_apply(net_params, xx)\n",
"        scores_norm = np.linalg.norm(scores, axis=-1, ord=2, keepdims=True)\n",
"        scores_log1p = scores / (scores_norm + 1e-9) * np.log1p(scores_norm)\n",
"        scores_norm = jnp.linalg.norm(scores, axis=-1, ord=2, keepdims=True)\n",
"        scores_log1p = scores / (scores_norm + 1e-9) * jnp.log1p(scores_norm)\n",
"\n",
"        plt.quiver(*xx.T, *scores_log1p.T, width=0.002, color='green')\n",
"        plt.xlim(-1.5, 2.0)\n",
@ -301,10 +301,10 @@
"plt.figure(figsize=[16, 16])\n",
"\n",
"net_params = get_params(opt_state)\n",
"xx = np.stack(np.meshgrid(np.linspace(-1.5, 1.5, 50), np.linspace(-1.5, 1.5, 50)), axis=-1).reshape(-1, 2)\n",
"xx = jnp.stack(jnp.meshgrid(jnp.linspace(-1.5, 1.5, 50), jnp.linspace(-1.5, 1.5, 50)), axis=-1).reshape(-1, 2)\n",
"scores = net_apply(net_params, xx)\n",
"scores_norm = np.linalg.norm(scores, axis=-1, ord=2, keepdims=True)\n",
"scores_log1p = scores / (scores_norm + 1e-9) * np.log1p(scores_norm)\n",
"scores_norm = jnp.linalg.norm(scores, axis=-1, ord=2, keepdims=True)\n",
"scores_log1p = scores / (scores_norm + 1e-9) * jnp.log1p(scores_norm)\n",
"\n",
"plt.quiver(*xx.T, *scores_log1p.T, width=0.002, color='green')\n",
"plt.scatter(*sample_batch(10_000).T, alpha=0.25)"
@ -412,11 +412,11 @@
"    for t in range(num_steps):\n",
"        key, subkey = jax.random.split(key)\n",
"        z_t = jax.random.normal(subkey, shape=x_t.shape)\n",
"        x_t = x_t + eps / 2 * net_apply(net_params, x_t) + np.sqrt(eps) * temperature * z_t\n",
"        x_t = x_t + eps / 2 * net_apply(net_params, x_t) + jnp.sqrt(eps) * temperature * z_t\n",
"        x_sequence.append(x_t)\n",
"        eps *= eps_decay\n",
"        \n",
"    return np.stack(x_sequence)"
"    return jnp.stack(x_sequence)"
],
"execution_count": 0,
"outputs": []
@ -438,7 +438,7 @@
"key = jax.random.PRNGKey(42)\n",
"net_params = get_params(opt_state)\n",
"\n",
"for x_initial in np.array([[-1.5, -1.5], [0, 0], [1.5, 0]]):\n",
"for x_initial in jnp.array([[-1.5, -1.5], [0, 0], [1.5, 0]]):\n",
"    key, subkey = jax.random.split(key)\n",
"    # sample x sequence\n",
"    xx = sample_langevin(x_initial, key=subkey, net_params=net_params)\n",
@ -446,16 +446,16 @@
"\n",
"    # draw arrows for each mcmc step\n",
"    deltas = (xx[1:] - xx[:-1])\n",
"    deltas = deltas - deltas / np.linalg.norm(deltas, keepdims=True, axis=-1) * 0.04\n",
"    deltas = deltas - deltas / jnp.linalg.norm(deltas, keepdims=True, axis=-1) * 0.04\n",
"    for i, arrow in enumerate(deltas):\n",
"        plt.arrow(xx[i][0], xx[i][1], arrow[0], arrow[1], width=1e-4, head_width=2e-2, color=\"orange\")\n",
"    \n",
"# plot data points and gradients\n",
"plt.plot()\n",
"xx = np.stack(np.meshgrid(np.linspace(-1.5, 1.5, 50), np.linspace(-1.5, 1.5, 50)), axis=-1).reshape(-1, 2)\n",
"xx = jnp.stack(jnp.meshgrid(jnp.linspace(-1.5, 1.5, 50), jnp.linspace(-1.5, 1.5, 50)), axis=-1).reshape(-1, 2)\n",
"scores = net_apply(net_params, xx)\n",
"scores_norm = np.linalg.norm(scores, axis=-1, ord=2, keepdims=True)\n",
"scores_log1p = scores / (scores_norm + 1e-9) * np.log1p(scores_norm)\n",
"scores_norm = jnp.linalg.norm(scores, axis=-1, ord=2, keepdims=True)\n",
"scores_log1p = scores / (scores_norm + 1e-9) * jnp.log1p(scores_norm)\n",
"plt.quiver(*xx.T, *scores_log1p.T, width=0.002, color='green')\n",
"plt.scatter(*sample_batch(10_000).T, alpha=0.025)"
],
@ -516,7 +516,7 @@
"@jax.jit\n",
"def compute_ssm_loss(net_params, inputs, key):\n",
"    apply = jax.jit(partial(net_apply, net_params))\n",
"    batch_dot = partial(np.einsum, 'bu,bu->b')\n",
"    batch_dot = partial(jnp.einsum, 'bu,bu->b')\n",
"    \n",
"    # generate random vectors from N(0, I)\n",
"    v = jax.random.normal(key, shape=inputs.shape)\n",
@ -524,7 +524,7 @@
"    # predict score and comput jacobian of score times v\n",
"    score, jac_v = jax.jvp(apply, [inputs], [v])\n",
"    \n",
"    return np.mean(batch_dot(v, jac_v) + 1/2 * batch_dot(v, score) ** 2)\n",
"    return jnp.mean(batch_dot(v, jac_v) + 1/2 * batch_dot(v, score) ** 2)\n",
"\n",
"@jax.jit\n",
"def train_step(step_i, opt_state, batch, key):\n",
@ -588,16 +588,16 @@
"        clear_output(True)\n",
"        plt.figure(figsize=[16, 8])\n",
"        plt.subplot(1, 2, 1)\n",
"        plt.title(\"mean loss = %.3f\" % np.mean(np.array(loss_history[-32:])))\n",
"        plt.scatter(np.arange(len(loss_history)), loss_history)\n",
"        plt.title(\"mean loss = %.3f\" % jnp.mean(jnp.array(loss_history[-32:])))\n",
"        plt.scatter(jnp.arange(len(loss_history)), loss_history)\n",
"        plt.grid()\n",
"        \n",
"        plt.subplot(1, 2, 2)\n",
"        net_params = get_params(opt_state)\n",
"        xx = np.stack(np.meshgrid(np.linspace(-1.5, 2.0, 50), np.linspace(-1.5, 2.0, 50)), axis=-1).reshape(-1, 2)\n",
"        xx = jnp.stack(jnp.meshgrid(jnp.linspace(-1.5, 2.0, 50), jnp.linspace(-1.5, 2.0, 50)), axis=-1).reshape(-1, 2)\n",
"        scores = net_apply(net_params, xx)\n",
"        scores_norm = np.linalg.norm(scores, axis=-1, ord=2, keepdims=True)\n",
"        scores_log1p = scores / (scores_norm + 1e-9) * np.log1p(scores_norm)\n",
"        scores_norm = jnp.linalg.norm(scores, axis=-1, ord=2, keepdims=True)\n",
"        scores_log1p = scores / (scores_norm + 1e-9) * jnp.log1p(scores_norm)\n",
"\n",
"        plt.quiver(*xx.T, *scores_log1p.T, width=0.002, color='green')\n",
"        plt.xlim(-1.5, 2.0)\n",
@ -644,7 +644,7 @@
},
"source": [
"from sklearn.datasets import load_digits\n",
"import numpy as old_np\n",
"import numpy as np\n",
"\n",
"X, _ = load_digits(return_X_y=True)\n",
"\n",
@ -654,8 +654,8 @@
"    \n",
"\n",
"def sample_batch(size, noise=0.1):\n",
"    ix = old_np.random.randint(0, len(X), size=size)\n",
"    return np.array(X[ix] / 16 + noise * old_np.random.randn(size, 64))"
"    ix = np.random.randint(0, len(X), size=size)\n",
"    return jnp.array(X[ix] / 16 + noise * np.random.randn(size, 64))"
],
"execution_count": 32,
"outputs": [
@ -721,8 +721,8 @@
"    \n",
"    if i % 500 == 0:\n",
"        clear_output(True)\n",
"        plt.title(\"mean loss = %.3f\" % np.mean(np.array(loss_history[-32:])))\n",
"        plt.scatter(np.arange(len(loss_history)), loss_history)\n",
"        plt.title(\"mean loss = %.3f\" % jnp.mean(jnp.array(loss_history[-32:])))\n",
"        plt.scatter(jnp.arange(len(loss_history)), loss_history)\n",
"        plt.show()"
],
"execution_count": 35,
@ -59,12 +59,12 @@
"import jax\n",
"\n",
"from jax import lax\n",
"from jax import numpy as np\n",
"from jax import scipy\n",
"import jax.numpy as jnp\n",
"import jax.scipy as jsp\n",
"from jax import random\n",
"\n",
"import numpy as onp\n",
"import scipy as oscipy"
"import numpy as np\n",
"import scipy as sp"
],
"execution_count": 0,
"outputs": []
@ -87,14 +87,14 @@
"colab": {}
},
"source": [
"onp.random.seed(10009)\n",
"np.random.seed(10009)\n",
"\n",
"num_features = 10\n",
"num_points = 100\n",
"\n",
"true_beta = onp.random.randn(num_features).astype(np.float32)\n",
"all_x = onp.random.randn(num_points, num_features).astype(np.float32)\n",
"y = (onp.random.rand(num_points) < oscipy.special.expit(all_x.dot(true_beta))).astype(np.int32)"
"true_beta = np.random.randn(num_features).astype(jnp.float32)\n",
"all_x = np.random.randn(num_points, num_features).astype(jnp.float32)\n",
"y = (np.random.rand(num_points) < sp.special.expit(all_x.dot(true_beta))).astype(jnp.int32)"
],
"execution_count": 0,
"outputs": []
@ -165,9 +165,9 @@
"source": [
"def log_joint(beta):\n",
"    result = 0.\n",
"    # Note that no `axis` parameter is provided to `np.sum`.\n",
"    result = result + np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=1.))\n",
"    result = result + np.sum(-np.log(1 + np.exp(-(2*y-1) * np.dot(all_x, beta))))\n",
"    # Note that no `axis` parameter is provided to `jnp.sum`.\n",
"    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.))\n",
"    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta))))\n",
"    return result"
],
"execution_count": 0,
@ -185,7 +185,7 @@
}
},
"source": [
"log_joint(onp.random.randn(num_features))"
"log_joint(np.random.randn(num_features))"
],
"execution_count": 13,
"outputs": [
@ -218,9 +218,9 @@
"# This doesn't work, because we didn't write `log_prob()` to handle batching.\n",
"try:\n",
"    batch_size = 10\n",
"    batched_test_beta = onp.random.randn(batch_size, num_features)\n",
"    batched_test_beta = np.random.randn(batch_size, num_features)\n",
"\n",
"    log_joint(onp.random.randn(batch_size, num_features))\n",
"    log_joint(np.random.randn(batch_size, num_features))\n",
"except ValueError as e:\n",
"    print(\"Caught expected exception \" + str(e))"
],
@ -258,12 +258,12 @@
"    # Here (and below) `sum` needs an `axis` parameter. At best, forgetting to set axis\n",
"    # or setting it incorrectly yields an error; at worst, it silently changes the\n",
"    # semantics of the model.\n",
"    result = result + np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=1.),\n",
"    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.),\n",
"                             axis=-1)\n",
"    # Note the multiple transposes. Getting this right is not rocket science,\n",
"    # but it's also not totally mindless. (I didn't get it right on the first\n",
"    # try.)\n",
"    result = result + np.sum(-np.log(1 + np.exp(-(2*y-1) * np.dot(all_x, beta.T).T)),\n",
"    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta.T).T)),\n",
"                             axis=-1)\n",
"    return result"
],
@ -283,7 +283,7 @@
},
"source": [
"batch_size = 10\n",
"batched_test_beta = onp.random.randn(batch_size, num_features)\n",
"batched_test_beta = np.random.randn(batch_size, num_features)\n",
"\n",
"batched_log_joint(batched_test_beta)"
],
@ -383,9 +383,9 @@
"@jax.jit\n",
"def log_joint(beta):\n",
"    result = 0.\n",
"    # Note that no `axis` parameter is provided to `np.sum`.\n",
"    result = result + np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=10.))\n",
"    result = result + np.sum(-np.log(1 + np.exp(-(2*y-1) * np.dot(all_x, beta))))\n",
"    # Note that no `axis` parameter is provided to `jnp.sum`.\n",
"    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=10.))\n",
"    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta))))\n",
"    return result\n",
"\n",
"batched_log_joint = jax.jit(jax.vmap(log_joint))"
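The hunk above is the notebook's punchline: `jax.vmap` derives the batched log-probability automatically, so the hand-transposed version with explicit `axis` arguments is not needed. A self-contained toy sketch of the same pattern (the `log_prob` here is a stand-in, not the notebook's model):

```python
import numpy as np
import jax
import jax.numpy as jnp

# A toy stand-in for log_joint: maps one parameter vector to a scalar.
def log_prob(beta):
    return jnp.sum(-0.5 * beta ** 2)

# vmap maps log_prob over the leading (batch) axis; jit compiles the result.
batched_log_prob = jax.jit(jax.vmap(log_prob))

batched_beta = np.random.randn(10, 3)        # (batch, features), illustrative
print(batched_log_prob(batched_beta).shape)  # (10,) -- one log-prob per row
```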
@ -412,8 +412,8 @@
},
"source": [
"def elbo(beta_loc, beta_log_scale, epsilon):\n",
"    beta_sample = beta_loc + np.exp(beta_log_scale) * epsilon\n",
"    return np.mean(batched_log_joint(beta_sample), 0) + np.sum(beta_log_scale - 0.5 * onp.log(2*onp.pi))\n",
"    beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon\n",
"    return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi))\n",
"    \n",
"elbo = jax.jit(elbo, static_argnums=(1, 2))\n",
"elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))"
@ -452,8 +452,8 @@
"\n",
"key = random.PRNGKey(10003)\n",
"\n",
"beta_loc = np.zeros(num_features, np.float32)\n",
"beta_log_scale = np.zeros(num_features, np.float32)\n",
"beta_loc = jnp.zeros(num_features, jnp.float32)\n",
"beta_log_scale = jnp.zeros(num_features, jnp.float32)\n",
"\n",
"step_size = 0.01\n",
"batch_size = 128\n",
@ -603,8 +603,8 @@
"source": [
"figure(figsize=(7, 7))\n",
"plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')\n",
"plot(true_beta, beta_loc + 2*np.exp(beta_log_scale), 'r.', label='Approximated Posterior $2\\sigma$ Error Bars')\n",
"plot(true_beta, beta_loc - 2*np.exp(beta_log_scale), 'r.')\n",
"plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label='Approximated Posterior $2\\sigma$ Error Bars')\n",
"plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')\n",
"plot_scale = 3\n",
"plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')\n",
"xlabel('True beta')\n",
@ -94,7 +94,7 @@ reference cycles.
Here is a simple example::

  from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node
  from jax import numpy as np
  import jax.numpy as jnp

  # The structured value to be transformed
  value_structured = [1., (2., 3.)]
@ -129,7 +129,7 @@ treated as leaves::
    (1., {'b': 2., 'a': 3.}),
    1.,
    None,
    np.zeros(2),
    jnp.zeros(2),
    Point(1., 2.)
  ]
  def show_example(structured):
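The hunk above ends just as `Point(1., 2.)` appears among the values treated as leaves. As a hedged, self-contained sketch of what `register_pytree_node` (imported in the example above) does for such a class, with a stand-in `Point` rather than the one defined elsewhere in the doc:

```python
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node

class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y

# flatten returns (children, aux_data); unflatten rebuilds from (aux_data, children).
register_pytree_node(
    Point,
    lambda pt: ((pt.x, pt.y), None),
    lambda aux, children: Point(*children))

leaves, treedef = tree_flatten([Point(1., 2.), jnp.zeros(2)])
print(leaves)                               # the Point's coordinates now appear as leaves
restored = tree_unflatten(treedef, leaves)  # same structure, Point rebuilt
```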
@ -9,9 +9,9 @@ surprising bugs where a silent rank promotion masks an underlying shape error.

Here's an example of rank promotion:

>>> import numpy as onp
>>> x = onp.arange(12).reshape(4, 3)
>>> y = onp.array([0, 1, 0])
>>> import numpy as np
>>> x = np.arange(12).reshape(4, 3)
>>> y = np.array([0, 1, 0])
>>> x + y
array([[ 0,  2,  2],
       [ 3,  5,  5],
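The doc this hunk comes from is about the `jax_numpy_rank_promotion` option; a minimal sketch of turning the silent promotion above into a warning (flag value is an assumption on my part, "warn" or "raise" being the usual choices, with "allow" as the default):

```python
import jax.numpy as jnp
from jax.config import config

# "warn" makes JAX warn on implicit rank promotion; "raise" turns it into an error.
config.update("jax_numpy_rank_promotion", "warn")

x = jnp.arange(12).reshape(4, 3)
y = jnp.array([0, 1, 0])
x + y  # rank-1 y broadcast against rank-2 x now triggers a warning
```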
@ -60,21 +60,21 @@ following table, where, for example
</table><p>

.. The table above was generated by the following Python code.
   import numpy as onp
   from jax import numpy as np
   import numpy as np
   import jax.numpy as jnp

   types = [onp.bool_, onp.uint8, onp.uint16, onp.uint32, onp.uint64,
            onp.int8, onp.int16, onp.int32, onp.int64,
            np.bfloat16, onp.float16, onp.float32, onp.float64,
            onp.complex64, onp.complex128]
   types = [np.bool_, np.uint8, np.uint16, np.uint32, np.uint64,
            np.int8, np.int16, np.int32, np.int64,
            jnp.bfloat16, np.float16, np.float32, np.float64,
            np.complex64, np.complex128]

   def name(d):
     d = onp.dtype(d)
     if d == onp.dtype(np.bfloat16):
     d = np.dtype(d)
     if d == np.dtype(jnp.bfloat16):
       return "bf"
     return "{}{}".format(
       d.kind,
       d.itemsize // 2 if onp.issubdtype(d, onp.complexfloating) else d.itemsize)
       d.itemsize // 2 if np.issubdtype(d, np.complexfloating) else d.itemsize)

   out = "<tr><th></th>"
   for t in types:
@ -84,8 +84,8 @@ following table, where, for example
   for t1 in types:
     out += "<tr><td>{}</td>".format(name(t1))
     for t2 in types:
       t = np.promote_types(t1, t2)
       different = np.bfloat16 in (t1, t2) or t != onp.promote_types(t1, t2)
       t = jnp.promote_types(t1, t2)
       different = jnp.bfloat16 in (t1, t2) or t != np.promote_types(t1, t2)
       out += "<td{}>{}</td>".format(" class=\"d\"" if different else "", name(t))
     out += "</tr>\n"
@ -28,7 +28,7 @@ pairs can be composed in series using `stax.serial` or in parallel using
Here’s an example:

```python
import jax.numpy as np
import jax.numpy as jnp
from jax import random
from jax.experimental import stax
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax
@ -48,7 +48,7 @@ in_shape = (-1, 28, 28, 1)
out_shape, net_params = net_init(rng, in_shape)

# Apply network to dummy inputs
inputs = np.zeros((128, 28, 28, 1))
inputs = jnp.zeros((128, 28, 28, 1))
predictions = net_apply(net_params, inputs)
```
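The hunk header above mentions composing `(init_fun, apply_fun)` pairs with `stax.serial`, but the `stax.serial(...)` call itself falls outside these hunks. A hedged sketch of what that composition looks like, with illustrative layer sizes rather than the README's:

```python
from jax import random
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, LogSoftmax

# serial composes (init_fun, apply_fun) layer pairs into a single pair.
net_init, net_apply = stax.serial(
    Dense(1024), Relu,
    Dense(10), LogSoftmax)

rng = random.PRNGKey(0)
in_shape = (-1, 784)                             # illustrative input shape
out_shape, net_params = net_init(rng, in_shape)  # initialize parameters
```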
@ -74,7 +74,7 @@ from jax import jit, grad
def loss(params, batch):
  inputs, targets = batch
  predictions = net_apply(params, inputs)
  return np.sum((predictions - targets)**2)
  return jnp.sum((predictions - targets)**2)

# Use optimizers to set optimizer initialization and update functions
opt_init, opt_update, get_params = optimizers.momentum(step_size=1e-3, mass=0.9)
@ -87,7 +87,7 @@ def step(i, opt_state, batch):
  return opt_update(i, g, opt_state)

# Dummy input data stream
data_generator = ((np.zeros((128, 28, 28, 1)), np.zeros((128, 10)))
data_generator = ((jnp.zeros((128, 28, 28, 1)), jnp.zeros((128, 10)))
                  for _ in range(10))

# Optimize parameters in a loop