Change onp/np to np/jnp in docs & notebooks (#3760)

Jake Vanderplas 2020-07-15 13:17:38 -07:00 committed by GitHub
parent 150d028d9d
commit 05904faf0f
23 changed files with 547 additions and 547 deletions
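
The renaming convention applied throughout the diff below, as a minimal sketch: plain NumPy drops the `onp` alias in favor of the usual `np`, and `jax.numpy` moves from `np` to `jnp`.

```python
# Old convention (replaced by this commit):
import numpy as onp       # "original" NumPy
import jax.numpy as np    # JAX's NumPy API

# New convention (introduced by this commit):
import numpy as np        # plain NumPy keeps its usual alias
import jax.numpy as jnp   # JAX's NumPy API gets the jnp alias
```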

View File

@ -52,18 +52,18 @@ bugs](https://github.com/google/jax/issues), and letting us know what you
think!
```python
import jax.numpy as np
import jax.numpy as jnp
from jax import grad, jit, vmap
def predict(params, inputs):
for W, b in params:
outputs = np.dot(inputs, W) + b
inputs = np.tanh(outputs)
outputs = jnp.dot(inputs, W) + b
inputs = jnp.tanh(outputs)
return outputs
def logprob_fun(params, inputs, targets):
preds = predict(params, inputs)
return np.sum((preds - targets)**2)
return jnp.sum((preds - targets)**2)
grad_fun = jit(grad(logprob_fun)) # compiled gradient evaluation function
perex_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0))) # fast per-example grads
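# A toy usage sketch of perex_grads, assuming the definitions above plus
# hypothetical shapes: a single 3->2 layer and a batch of 4 examples.
import numpy as np  # plain NumPy for random test data, per the new convention

params = [(np.random.randn(3, 2), np.random.randn(2))]  # one (W, b) layer
inputs = np.random.randn(4, 3)    # batch of 4 input vectors
targets = np.random.randn(4, 2)   # batch of 4 target vectors

per_example = perex_grads(params, inputs, targets)
# per_example mirrors the structure of params with a leading batch axis:
# the W-gradient has shape (4, 3, 2) and the b-gradient has shape (4, 2).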
@ -114,10 +114,10 @@ for reverse-mode gradients:
```python
from jax import grad
import jax.numpy as np
import jax.numpy as jnp
def tanh(x): # Define a function
y = np.exp(-2.0 * x)
y = jnp.exp(-2.0 * x)
return (1.0 - y) / (1.0 + y)
grad_tanh = grad(tanh) # Obtain its gradient function
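# A quick hedged check: the gradient at 1.0 is the analytic value
# 1 - tanh(1)**2, roughly 0.42.
print(grad_tanh(1.0))  # ~0.4199743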
@ -176,14 +176,14 @@ You can use XLA to compile your functions end-to-end with
used either as an `@jit` decorator or as a higher-order function.
```python
import jax.numpy as np
import jax.numpy as jnp
from jax import jit
def slow_f(x):
# Element-wise ops see a large benefit from fusion
return x * x + x * 2.0
x = np.ones((5000, 5000))
x = jnp.ones((5000, 5000))
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x) # ~ 4.5 ms / loop on Titan X
%timeit -n10 -r3 slow_f(x) # ~ 14.5 ms / loop (also on GPU via JAX)
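# A hedged timing note: because of JAX's asynchronous dispatch, a stricter
# microbenchmark would block on the result before stopping the clock, e.g.
%timeit -n10 -r3 fast_f(x).block_until_ready()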
@ -213,19 +213,19 @@ function:
def predict(params, input_vec):
assert input_vec.ndim == 1
for W, b in params:
output_vec = np.dot(W, input_vec) + b # `input_vec` on the right-hand side!
input_vec = np.tanh(output_vec)
output_vec = jnp.dot(W, input_vec) + b # `input_vec` on the right-hand side!
input_vec = jnp.tanh(output_vec)
return output_vec
```
We often instead write `np.dot(inputs, W)` to allow for a batch dimension on the
We often instead write `jnp.dot(inputs, W)` to allow for a batch dimension on the
left side of `inputs`, but we've written this particular prediction function to
apply only to single input vectors. If we wanted to apply this function to a
batch of inputs at once, semantically we could just write
```python
from functools import partial
predictions = np.stack(list(map(partial(predict, params), input_batch)))
predictions = jnp.stack(list(map(partial(predict, params), input_batch)))
```
But pushing one example through the network at a time would be slow! It's better
@ -273,17 +273,17 @@ Here's an example on an 8-GPU machine:
```python
from jax import random, pmap
import jax.numpy as np
import jax.numpy as jnp
# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random.split(random.PRNGKey(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)
# Run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: np.dot(x, x.T))(mats) # result.shape is (8, 5000, 5000)
result = pmap(lambda x: jnp.dot(x, x.T))(mats) # result.shape is (8, 5000, 5000)
# Compute the mean on each device in parallel and print the result
print(pmap(np.mean)(result))
print(pmap(jnp.mean)(result))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
```
@ -299,7 +299,7 @@ from jax import lax
def normalize(x):
return x / lax.psum(x, 'i')
print(normalize(np.arange(4.)))
print(normalize(jnp.arange(4.)))
# prints [0. 0.16666667 0.33333334 0.5 ]
```
@ -313,11 +313,11 @@ from jax import grad
@pmap
def f(x):
y = np.sin(x)
y = jnp.sin(x)
@pmap
def g(z):
return np.cos(z) * np.tan(y.sum()) * np.tanh(x).sum()
return grad(lambda w: np.sum(g(w)))(x)
return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
return grad(lambda w: jnp.sum(g(w)))(x)
print(f(x))
# [[ 0. , -0.7170853 ],
@ -325,7 +325,7 @@ print(f(x))
# [10.366636 , 13.135289 ],
# [ 0.22163185, -0.52112055]]
print(grad(lambda x: np.sum(f(x)))(x))
print(grad(lambda x: jnp.sum(f(x)))(x))
# [[ -3.2369726, -1.6356447],
# [ 4.7572474, 11.606951 ],
# [-98.524414 , 42.76499 ],

View File

@ -59,7 +59,7 @@
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as np\n",
"import jax.numpy as jnp\n",
"from jax import random\n",
"\n",
"key = random.PRNGKey(0)"
@ -92,7 +92,7 @@
},
"outputs": [],
"source": [
"y = np.dot(x, x)\n",
"y = jnp.dot(x, x)\n",
"print(y[0, 0])"
]
},
@ -134,7 +134,7 @@
},
"outputs": [],
"source": [
"np.dot(x, x.T)"
"jnp.dot(x, x.T)"
]
},
{
@ -147,7 +147,7 @@
},
"outputs": [],
"source": [
"print(np.dot(x, 2 * x)[[0, 2, 1, 0], ..., None, ::-1])"
"print(jnp.dot(x, 2 * x)[[0, 2, 1, 0], ..., None, ::-1])"
]
},
{
@ -160,10 +160,10 @@
},
"outputs": [],
"source": [
"import numpy as onp\n",
"import numpy as np\n",
"\n",
"x_cpu = onp.array(x)\n",
"%timeit -n 1 -r 1 onp.dot(x_cpu, x_cpu)"
"x_cpu = np.array(x)\n",
"%timeit -n 1 -r 1 np.dot(x_cpu, x_cpu)"
]
},
{
@ -176,7 +176,7 @@
},
"outputs": [],
"source": [
"%timeit -n 5 -r 5 np.dot(x, x).block_until_ready()"
"%timeit -n 5 -r 5 jnp.dot(x, x).block_until_ready()"
]
},
{
@ -262,14 +262,14 @@
"source": [
"def predict(params, inputs):\n",
" for W, b in params:\n",
" outputs = np.dot(inputs, W) + b\n",
" inputs = np.tanh(outputs)\n",
" outputs = jnp.dot(inputs, W) + b\n",
" inputs = jnp.tanh(outputs)\n",
" return outputs\n",
"\n",
"def loss(params, batch):\n",
" inputs, targets = batch\n",
" predictions = predict(params, inputs)\n",
" return np.sum((predictions - targets)**2)\n",
" return jnp.sum((predictions - targets)**2)\n",
"\n",
"\n",
"\n",
@ -459,7 +459,7 @@
},
"outputs": [],
"source": [
"grad(jit(grad(jit(grad(np.tanh)))))(1.0)"
"grad(jit(grad(jit(grad(jnp.tanh)))))(1.0)"
]
},
{
@ -551,7 +551,7 @@
},
"outputs": [],
"source": [
"f(np.array([1., 2., 3.]), 5)"
"f(jnp.array([1., 2., 3.]), 5)"
]
},
{
@ -565,7 +565,7 @@
"outputs": [],
"source": [
"try:\n",
" g(np.array([1., 2., 3.]), 5)\n",
" g(jnp.array([1., 2., 3.]), 5)\n",
"except Exception as e:\n",
" print(e)\n",
" pass"
@ -594,7 +594,7 @@
},
"outputs": [],
"source": [
"g(np.array([1., 2., 3.]), 5)"
"g(jnp.array([1., 2., 3.]), 5)"
]
},
{
@ -630,7 +630,7 @@
},
"outputs": [],
"source": [
"print(vmap(lambda x: x**2)(np.arange(8)))"
"print(vmap(lambda x: x**2)(jnp.arange(8)))"
]
},
{
@ -645,7 +645,7 @@
"source": [
"from jax import make_jaxpr\n",
"\n",
"make_jaxpr(np.dot)(np.ones(8), np.ones(8))"
"make_jaxpr(jnp.dot)(jnp.ones(8), jnp.ones(8))"
]
},
{
@ -658,7 +658,7 @@
},
"outputs": [],
"source": [
"make_jaxpr(vmap(np.dot))(np.ones((10, 8)), np.ones((10, 8)))"
"make_jaxpr(vmap(jnp.dot))(jnp.ones((10, 8)), jnp.ones((10, 8)))"
]
},
{
@ -671,7 +671,7 @@
},
"outputs": [],
"source": [
"make_jaxpr(vmap(vmap(np.dot)))(np.ones((10, 10, 8)), np.ones((10, 10, 8)))"
"make_jaxpr(vmap(vmap(jnp.dot)))(jnp.ones((10, 10, 8)), jnp.ones((10, 10, 8)))"
]
},
{
@ -734,7 +734,7 @@
},
"outputs": [],
"source": [
"y = pmap(lambda x: x ** 2)(np.arange(8))\n",
"y = pmap(lambda x: x ** 2)(jnp.arange(8))\n",
"print(y)"
]
},
@ -791,8 +791,8 @@
"source": [
"keys = random.split(random.PRNGKey(0), 8)\n",
"mats = pmap(lambda key: random.normal(key, (5000, 5000)))(keys)\n",
"result = pmap(np.dot)(mats, mats)\n",
"print(pmap(np.mean)(result))"
"result = pmap(jnp.dot)(mats, mats)\n",
"print(pmap(jnp.mean)(result))"
]
},
{
@ -805,7 +805,7 @@
},
"outputs": [],
"source": [
"timeit -n 5 -r 5 pmap(np.dot)(mats, mats).block_until_ready()"
"timeit -n 5 -r 5 pmap(jnp.dot)(mats, mats).block_until_ready()"
]
},
{
@ -835,7 +835,7 @@
"def normalize(x):\n",
" return x / psum(x, 'i')\n",
"\n",
"print(normalize(np.arange(8.)))"
"print(normalize(jnp.arange(8.)))"
]
},
{
@ -856,7 +856,7 @@
" total_sum = psum(x, ('rows', 'cols'))\n",
" return row_sum, col_sum, total_sum\n",
"\n",
"x = np.arange(8.).reshape((4, 2))\n",
"x = jnp.arange(8.).reshape((4, 2))\n",
"a, b, c = f(x)\n",
"\n",
"print(\"input:\\n\", x)\n",
@ -907,11 +907,11 @@
"source": [
"@pmap\n",
"def f(x):\n",
" y = np.sin(x)\n",
" y = jnp.sin(x)\n",
" @pmap\n",
" def g(z):\n",
" return np.cos(z) * np.tan(y.sum()) * np.tanh(x).sum()\n",
" return grad(lambda w: np.sum(g(w)))(x)\n",
" return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()\n",
" return grad(lambda w: jnp.sum(g(w)))(x)\n",
" \n",
"f(x)"
]
@ -926,7 +926,7 @@
},
"outputs": [],
"source": [
"grad(lambda x: np.sum(f(x)))(x)"
"grad(lambda x: jnp.sum(f(x)))(x)"
]
},
{
@ -956,16 +956,16 @@
"@curry\n",
"def jacfwd(fun, x):\n",
" pushfwd = partial(jvp, fun, (x,)) # jvp!\n",
" std_basis = np.eye(onp.size(x)).reshape((-1,) + np.shape(x)),\n",
" std_basis = jnp.eye(np.size(x)).reshape((-1,) + jnp.shape(x)),\n",
" y, jac_flat = vmap(pushfwd, out_axes=(None, -1))(std_basis) # vmap!\n",
" return jac_flat.reshape(np.shape(y) + np.shape(x))\n",
" return jac_flat.reshape(jnp.shape(y) + jnp.shape(x))\n",
"\n",
"@curry\n",
"def jacrev(fun, x):\n",
" y, pullback = vjp(fun, x) # vjp!\n",
" std_basis = np.eye(onp.size(y)).reshape((-1,) + np.shape(y))\n",
" std_basis = jnp.eye(np.size(y)).reshape((-1,) + jnp.shape(y))\n",
" jac_flat, = vmap(pullback)(std_basis) # vmap!\n",
" return jac_flat.reshape(np.shape(y) + np.shape(x))\n",
" return jac_flat.reshape(jnp.shape(y) + jnp.shape(x))\n",
"\n",
"def hessian(fun):\n",
" return jit(jacfwd(jacrev(fun))) # jit!"

View File

@ -80,9 +80,9 @@
"source": [
"import io\n",
"from functools import partial\n",
"import numpy as onp\n",
"import numpy as np\n",
"import jax\n",
"import jax.numpy as np\n",
"import jax.numpy as jnp\n",
"from jax import vmap, jit, grad, ops, lax, config\n",
"from jax import random as jr\n",
"\n",
@ -130,15 +130,15 @@
" https://en.wikipedia.org/wiki/Xiaolin_Wu's_line_algorithm\n",
" \"\"\"\n",
"\n",
" ipart = lambda x: np.floor(x).astype('int32')\n",
" ipart = lambda x: jnp.floor(x).astype('int32')\n",
" round_ = lambda x: ipart(x + 0.5).astype('int32')\n",
" fpart = lambda x: x - np.floor(x)\n",
" fpart = lambda x: x - jnp.floor(x)\n",
" rfpart = lambda x: 1 - fpart(x)\n",
"\n",
" def plot(im, x, y, c):\n",
" return ops.index_add(im, ops.index[x, y], c)\n",
"\n",
" steep = np.abs(y1 - y0) > np.abs(x1 - x0)\n",
" steep = jnp.abs(y1 - y0) > jnp.abs(x1 - x0)\n",
" cond_swap = lambda cond, x: lax.cond(cond, x, lambda x: (x[1], x[0]), x, lambda x: x)\n",
" \n",
" (x0, y0) = cond_swap(steep, (x0, y0))\n",
@ -149,7 +149,7 @@
"\n",
" dx = x1 - x0\n",
" dy = y1 - y0\n",
" gradient = np.where(dx == 0.0, 1.0, dy/dx)\n",
" gradient = jnp.where(dx == 0.0, 1.0, dy/dx)\n",
"\n",
" # handle first endpoint\n",
" xend = round_(x0)\n",
@ -211,12 +211,12 @@
" return im\n",
"\n",
"def img_adjust(data):\n",
" oim = onp.array(data)\n",
" hist, bin_edges = onp.histogram(oim.flat, bins=256*256)\n",
" oim = np.array(data)\n",
" hist, bin_edges = np.histogram(oim.flat, bins=256*256)\n",
" bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2\n",
" cdf = hist.cumsum()\n",
" cdf = cdf / float(cdf[-1])\n",
" return onp.interp(oim.flat, bin_centers, cdf).reshape(oim.shape)\n",
" return np.interp(oim.flat, bin_centers, cdf).reshape(oim.shape)\n",
"\n",
"def imify(arr, vmin=None, vmax=None, cmap=None, origin=None):\n",
" arr = img_adjust(arr)\n",
@ -239,16 +239,16 @@
" display_png(dat, raw=True)\n",
"\n",
"def pack_images(images, rows, cols):\n",
" shape = onp.shape(images)\n",
" shape = np.shape(images)\n",
" width, height, depth = shape[-3:]\n",
" images = onp.reshape(images, (-1, width, height, depth))\n",
" batch = onp.shape(images)[0]\n",
" rows = onp.minimum(rows, batch)\n",
" cols = onp.minimum(batch // rows, cols)\n",
" images = np.reshape(images, (-1, width, height, depth))\n",
" batch = np.shape(images)[0]\n",
" rows = np.minimum(rows, batch)\n",
" cols = np.minimum(batch // rows, cols)\n",
" images = images[:rows * cols]\n",
" images = onp.reshape(images, (rows, cols, width, height, depth))\n",
" images = onp.transpose(images, [0, 2, 1, 3, 4])\n",
" images = onp.reshape(images, [rows * width, cols * height, depth])\n",
" images = np.reshape(images, (rows, cols, width, height, depth))\n",
" images = np.transpose(images, [0, 2, 1, 3, 4])\n",
" images = np.reshape(images, [rows * width, cols * height, depth])\n",
" return images"
],
"execution_count": 0,
@ -281,7 +281,7 @@
"@jit\n",
"def f(state, t):\n",
" x, y, z = state\n",
" return np.array([sigma * (y - x), x * (rho - z) - y, x * y - beta * z])"
" return jnp.array([sigma * (y - x), x * (rho - z) - y, x * y - beta * z])"
],
"execution_count": 0,
"outputs": []
@ -343,8 +343,8 @@
"N = 40000\n",
"\n",
"# set initial condition\n",
"state0 = np.array([1., 1., 1.])\n",
"ys = np.zeros((N,) + state0.shape)\n",
"state0 = jnp.array([1., 1., 1.])\n",
"ys = jnp.zeros((N,) + state0.shape)\n",
"ys = ops.index_update(ys, ops.index[0], state0)\n",
"\n",
"# solve for N steps\n",
@ -368,7 +368,7 @@
"# fast, jitted plotting function\n",
"@partial(jax.jit, static_argnums=(2,3,4,5))\n",
"def jplotter(xs, zs, xlim, zlim, xN, zN):\n",
" im = np.zeros((xN, zN))\n",
" im = jnp.zeros((xN, zN))\n",
" xpixels = (xs - xlim[0])/(1.0 * (xlim[1] - xlim[0])) * xN\n",
" zpixels = (zs - zlim[0])/(1.0 * (zlim[1] - zlim[0])) * zN\n",
" def body_fun(i, im):\n",
@ -403,17 +403,17 @@
"N = 4000\n",
"\n",
"# set some initial conditions for each replicate\n",
"ys = np.zeros((N_dev, N, 3))\n",
"ys = jnp.zeros((N_dev, N, 3))\n",
"state0 = jr.uniform(jr.PRNGKey(1), \n",
" minval=-1., maxval=1.,\n",
" shape=(N_dev, 3))\n",
"state0 = state0 * np.array([18,18,1]) + np.array((0.,0.,10.))\n",
"state0 = state0 * jnp.array([18,18,1]) + jnp.array((0.,0.,10.))\n",
"ys = ops.index_update(ys, ops.index[:, 0], state0)\n",
"\n",
"# solve each replicate in parallel using `pmap` of rk4 solver:\n",
"ys = jax.pmap(rk4)(ys, \n",
" 0.004 * np.ones(N_dev), \n",
" N * np.ones(N_dev, dtype=onp.int32)\n",
" 0.004 * jnp.ones(N_dev), \n",
" N * jnp.ones(N_dev, dtype=np.int32)\n",
" ).block_until_ready()"
],
"execution_count": 0,
@ -430,7 +430,7 @@
"# parallel plotter using lexical closure and pmap'd core plotting function\n",
"def pplotter(_xs, _zs, xlim, zlim, xN, zN):\n",
" N_dev = _xs.shape[0]\n",
" im = np.zeros((N_dev, xN, zN))\n",
" im = jnp.zeros((N_dev, xN, zN))\n",
" @jax.pmap\n",
" def plotfn(im, xs, zs):\n",
" xpixels = (xs - xlim[0])/(1.0 * (xlim[1] - xlim[0])) * xN\n",
@ -459,7 +459,7 @@
"plot_image(im[:,::-1].T, cmap='magma')\n",
"# below, plot combined ODE traces\n",
"ims = pplotter(ys[...,0], ys[...,2], xlim, zlim, xN*4, zN*4)\n",
"plot_image(np.sum(ims, axis=0)[:,::-1].T, cmap='magma')"
"plot_image(jnp.sum(ims, axis=0)[:,::-1].T, cmap='magma')"
],
"execution_count": 0,
"outputs": []

View File

@ -81,7 +81,7 @@
"colab": {}
},
"source": [
"import jax.numpy as np"
"import jax.numpy as jnp"
],
"execution_count": 0,
"outputs": []
@ -137,7 +137,7 @@
"colab": {}
},
"source": [
"result = pmap(lambda x: x ** 2)(np.arange(7))\n",
"result = pmap(lambda x: x ** 2)(jnp.arange(7))\n",
"print(result)"
],
"execution_count": 0,
@ -163,11 +163,11 @@
"source": [
"from jax import vmap\n",
"\n",
"x = np.array([1., 2., 3.])\n",
"y = np.array([2., 4., 6.])\n",
"x = jnp.array([1., 2., 3.])\n",
"y = jnp.array([2., 4., 6.])\n",
"\n",
"print(vmap(np.add)(x, y))\n",
"print(pmap(np.add)(x, y))"
"print(vmap(jnp.add)(x, y))\n",
"print(pmap(jnp.add)(x, y))"
],
"execution_count": 0,
"outputs": []
@ -193,12 +193,12 @@
"from jax import make_jaxpr\n",
"\n",
"def f(x, y):\n",
" a = np.dot(x, y)\n",
" b = np.tanh(a)\n",
" a = jnp.dot(x, y)\n",
" b = jnp.tanh(a)\n",
" return b\n",
"\n",
"xs = np.ones((8, 2, 3))\n",
"ys = np.ones((8, 3, 4))\n",
"xs = jnp.ones((8, 2, 3))\n",
"ys = jnp.ones((8, 3, 4))\n",
"\n",
"print(\"f jaxpr\")\n",
"print(make_jaxpr(f)(xs[0], ys[0]))\n",
@ -236,7 +236,7 @@
"colab": {}
},
"source": [
"y = pmap(lambda x: x ** 2)(np.arange(8))\n",
"y = pmap(lambda x: x ** 2)(jnp.arange(8))\n",
"z = y / 2\n",
"print(z)"
],
@ -313,8 +313,8 @@
"colab": {}
},
"source": [
"import numpy as onp\n",
"onp.sin(y)"
"import numpy as np\n",
"np.sin(y)"
],
"execution_count": 0,
"outputs": []
@ -360,7 +360,7 @@
},
"source": [
"# run a local matmul on each device in parallel (no data transfer)\n",
"result = pmap(lambda x: np.dot(x, x.T))(mats)\n",
"result = pmap(lambda x: jnp.dot(x, x.T))(mats)\n",
"result.shape"
],
"execution_count": 0,
@ -375,7 +375,7 @@
},
"source": [
"# compute the mean on each device in parallel and print the results\n",
"print(pmap(np.mean)(result))"
"print(pmap(jnp.mean)(result))"
],
"execution_count": 0,
"outputs": []
@ -423,7 +423,7 @@
"from jax import lax\n",
"\n",
"normalize = lambda x: x / lax.psum(x, axis_name='i')\n",
"result = pmap(normalize, axis_name='i')(np.arange(4.))\n",
"result = pmap(normalize, axis_name='i')(jnp.arange(4.))\n",
"print(result)"
],
"execution_count": 0,
@ -455,7 +455,7 @@
"def normalize(x):\n",
" return x / lax.psum(x, 'i')\n",
"\n",
"print(normalize(np.arange(4.)))"
"print(normalize(jnp.arange(4.)))"
],
"execution_count": 0,
"outputs": []
@ -486,7 +486,7 @@
" doubly_normed = x / lax.psum(x, ('rows', 'cols'))\n",
" return row_normed, col_normed, doubly_normed\n",
"\n",
"x = np.arange(8.).reshape((4, 2))\n",
"x = jnp.arange(8.).reshape((4, 2))\n",
"a, b, c = f(x)\n",
"\n",
"print(a)\n",
@ -538,14 +538,14 @@
"def step(board_slice):\n",
" left, right = board_slice[:1], board_slice[-1:]\n",
" right, left = send_left(left, 'i'), send_right(right, 'i')\n",
" enlarged_board_slice = np.concatenate([left, board_slice, right])\n",
" enlarged_board_slice = jnp.concatenate([left, board_slice, right])\n",
" return update_board(enlarged_board_slice)\n",
"\n",
"def print_board(board):\n",
" print(''.join('*' if x else ' ' for x in board.ravel()))\n",
"\n",
"\n",
"board = onp.zeros(40, dtype=bool)\n",
"board = np.zeros(40, dtype=bool)\n",
"board[board.shape[0] // 2] = True\n",
"reshaped_board = board.reshape((device_count, -1))\n",
"\n",
@ -589,11 +589,11 @@
"\n",
"@pmap\n",
"def f(x):\n",
" y = np.sin(x)\n",
" y = jnp.sin(x)\n",
" @pmap\n",
" def g(z):\n",
" return np.cos(z) * np.tan(y.sum()) * np.tanh(x).sum()\n",
" return grad(lambda w: np.sum(g(w)))(x)\n",
" return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()\n",
" return grad(lambda w: jnp.sum(g(w)))(x)\n",
" \n",
"f(x)"
],
@ -608,7 +608,7 @@
"colab": {}
},
"source": [
"grad(lambda x: np.sum(f(x)))(x)"
"grad(lambda x: jnp.sum(f(x)))(x)"
],
"execution_count": 0,
"outputs": []

View File

@ -84,8 +84,8 @@
"from jax import jit, pmap\n",
"from jax import lax\n",
"from jax import tree_util\n",
"import jax.numpy as np\n",
"import numpy as onp\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import skimage.filters\n",
"import proglog\n",
@ -123,13 +123,13 @@
"def halo_exchange_padding(array, padding=1, axis=0, axis_name='x'):\n",
" if not padding > 0:\n",
" raise ValueError(f'invalid padding: {padding}')\n",
" array = np.array(array)\n",
" array = jnp.array(array)\n",
" if array.ndim == 0:\n",
" return array\n",
" left = slice_along_axis(array, slice(None, padding), axis)\n",
" right = slice_along_axis(array, slice(-padding, None), axis)\n",
" right, left = send_left(left, axis_name), send_right(right, axis_name)\n",
" return np.concatenate([left, array, right], axis)\n",
" return jnp.concatenate([left, array, right], axis)\n",
"\n",
"@tree_vectorize\n",
"def halo_exchange_inplace(array, padding=1, axis=0, axis_name='x'):\n",
@ -153,14 +153,14 @@
" new_shape = list(array.shape)\n",
" new_shape[split_axis] = tile_size\n",
" new_shape.insert(split_axis, num_splits)\n",
" return np.moveaxis(np.reshape(array, new_shape), split_axis, tile_id_axis)\n",
" return jnp.moveaxis(jnp.reshape(array, new_shape), split_axis, tile_id_axis)\n",
"\n",
"def stack_with_reshape(array, *, split_axis=0, tile_id_axis=None):\n",
" if tile_id_axis is None:\n",
" tile_id_axis = split_axis\n",
" array = np.moveaxis(array, tile_id_axis, split_axis)\n",
" array = jnp.moveaxis(array, tile_id_axis, split_axis)\n",
" new_shape = array.shape[:split_axis] + (-1,) + array.shape[split_axis+2:]\n",
" return np.reshape(array, new_shape)\n",
" return jnp.reshape(array, new_shape)\n",
"\n",
"def shard(func):\n",
" def wrapper(state):\n",
@ -178,7 +178,7 @@
" sliced = slice_along_axis(array, index, axis)\n",
" padding = [(0, 0)] * array.ndim\n",
" padding[axis] = (-min(offset, 0), max(offset, 0))\n",
" return np.pad(sliced, padding, mode='constant', constant_values=0)\n",
" return jnp.pad(sliced, padding, mode='constant', constant_values=0)\n",
"\n",
"def laplacian(array, step=1):\n",
" left = shift(array, +1, axis=0)\n",
@ -204,10 +204,10 @@
"\n",
"# Time stepping\n",
"\n",
"def multi_step(state, count, dt=1/np.sqrt(2), c=1):\n",
"def multi_step(state, count, dt=1/jnp.sqrt(2), c=1):\n",
" return lax.fori_loop(0, count, lambda i, s: leapfrog_step(s, dt, c), state)\n",
"\n",
"def multi_step_pmap(state, count, dt=1/np.sqrt(2), c=1, exchange_interval=1,\n",
"def multi_step_pmap(state, count, dt=1/jnp.sqrt(2), c=1, exchange_interval=1,\n",
" save_interval=1):\n",
"\n",
" def exchange_and_multi_step(state_padded):\n",
@ -231,7 +231,7 @@
" tree_util.tree_map(lambda x: x.copy_to_host_async(), state)\n",
" results.append(state)\n",
" results = jax.device_get(results)\n",
" return tree_util.tree_multimap(lambda *xs: onp.stack([onp.array(x) for x in xs]), *results)\n",
" return tree_util.tree_multimap(lambda *xs: np.stack([np.array(x) for x in xs]), *results)\n",
"\n",
"multi_step_jit = jax.jit(multi_step)"
]
@ -256,9 +256,9 @@
},
"outputs": [],
"source": [
"x = np.linspace(0, 8, num=8*1024, endpoint=False)\n",
"y = np.linspace(0, 1, num=1*1024, endpoint=False)\n",
"x_mesh, y_mesh = np.meshgrid(x, y, indexing='ij')\n",
"x = jnp.linspace(0, 8, num=8*1024, endpoint=False)\n",
"y = jnp.linspace(0, 1, num=1*1024, endpoint=False)\n",
"x_mesh, y_mesh = jnp.meshgrid(x, y, indexing='ij')\n",
"\n",
"# NOTE: smooth initial conditions are important, so we aren't exciting\n",
"# arbitrarily high frequencies (that cannot be resolved)\n",
@ -266,13 +266,13 @@
" ((x_mesh - 1/3) ** 2 + (y_mesh - 1/4) ** 2) < 0.1 ** 2,\n",
" sigma=1)\n",
"\n",
"# u = np.exp(-((x_mesh - 1/3) ** 2 + (y_mesh - 1/4) ** 2) / 0.1 ** 2)\n",
"# u = jnp.exp(-((x_mesh - 1/3) ** 2 + (y_mesh - 1/4) ** 2) / 0.1 ** 2)\n",
"\n",
"# u = skimage.filters.gaussian(\n",
"# (x_mesh > 1/3) & (x_mesh < 1/2) & (y_mesh > 1/3) & (y_mesh < 1/2),\n",
"# sigma=5)\n",
"\n",
"v = np.zeros_like(u)\n",
"v = jnp.zeros_like(u)\n",
"c = 1 # could also use a 2D array matching the mesh shape"
]
},
@ -445,7 +445,7 @@
" images = []\n",
" for frame in data:\n",
" if vmax is None:\n",
" this_vmax = onp.max(abs(frame))\n",
" this_vmax = np.max(abs(frame))\n",
" else:\n",
" this_vmax = vmax\n",
" norm = matplotlib.colors.Normalize(vmin=-this_vmax, vmax=this_vmax)\n",
@ -474,7 +474,7 @@
"source": [
"# Show Movie\n",
"proglog.default_bar_logger = partial(proglog.default_bar_logger, None)\n",
"ImageSequenceClip([onp.array(im) for im in images], fps=25).ipython_display()"
"ImageSequenceClip([np.array(im) for im in images], fps=25).ipython_display()"
]
},
{

View File

@ -4,11 +4,11 @@ Asynchronous dispatch
JAX uses asynchronous dispatch to hide Python overheads. Consider the following
program:
>>> import numpy as onp
>>> from jax import numpy as np
>>> import numpy as np
>>> import jax.numpy as jnp
>>> from jax import random
>>> x = random.uniform(random.PRNGKey(0), (1000, 1000))
>>> np.dot(x, x) + 3. # doctest: +SKIP
>>> jnp.dot(x, x) + 3. # doctest: +SKIP
DeviceArray([[258.01971436, 249.64862061, 257.13372803, ...,
236.67948914, 250.68939209, 241.36853027],
[265.65979004, 256.28912354, 262.18252563, ...,
@ -23,7 +23,7 @@ DeviceArray([[258.01971436, 249.64862061, 257.13372803, ...,
[257.16134644, 254.7543335, 259.08300781, ..., 241.59848022,
248.62597656, 243.22348022]], dtype=float32)
When an operation such as :code:`np.dot(x, x)` is executed, JAX does not wait
When an operation such as :code:`jnp.dot(x, x)` is executed, JAX does not wait
for the operation to complete before returning control to the Python program.
Instead, JAX returns a :class:`~jax.DeviceArray` value, which is a future,
i.e., a value that will be produced in the future on an accelerator device but
@ -44,7 +44,7 @@ arbitrary amounts of work and avoid having the accelerator wait.
Asynchronous dispatch has a slightly surprising consequence for microbenchmarks.
>>> %time np.dot(x, x) # doctest: +SKIP
>>> %time jnp.dot(x, x) # doctest: +SKIP
CPU times: user 267 µs, sys: 93 µs, total: 360 µs
Wall time: 269 µs
DeviceArray([[255.01972961, 246.64862061, 254.13371277, ...,
@ -70,7 +70,7 @@ use the :meth:`~jax.DeviceArray.block_until_ready` method on a
:class:`DeviceArray` value to wait for the computation that produced it to
complete.
>>> %time onp.asarray(np.dot(x, x)) # doctest: +SKIP
>>> %time np.asarray(jnp.dot(x, x)) # doctest: +SKIP
CPU times: user 61.1 ms, sys: 0 ns, total: 61.1 ms
Wall time: 8.09 ms
Out[16]:
@ -87,7 +87,7 @@ array([[255.01973, 246.64862, 254.13371, ..., 233.67949, 247.68939,
258.337 ],
[254.16135, 251.75433, 256.083 , ..., 238.59848, 245.62598,
240.22348]], dtype=float32)
>>> %time np.dot(x, x).block_until_ready() # doctest: +SKIP
>>> %time jnp.dot(x, x).block_until_ready() # doctest: +SKIP
CPU times: user 50.3 ms, sys: 928 µs, total: 51.2 ms
Wall time: 4.92 ms
DeviceArray([[255.01972961, 246.64862061, 254.13371277, ...,

View File

@ -86,7 +86,7 @@ The jaxpr primitives are documented in the :py:mod:`jax.lax` module.
For example, here is the jaxpr produced for the function ``func1`` below
>>> from jax import make_jaxpr
>>> from jax import numpy as jnp
>>> import jax.numpy as jnp
>>> def func1(first, second):
... temp = first + jnp.sin(second) * 3.
... return jnp.sum(temp)
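A hedged sketch of producing the jaxpr for ``func1`` (the concrete arguments used in the full document are not shown in this hunk; small scalar values stand in):
>>> print(make_jaxpr(func1)(3.0, 4.0))  # doctest: +SKIP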
@ -187,7 +187,7 @@ JAX produces the following jaxpr
in (e,) }
When tracing ``func6``, the function ``func5`` is invoked with a constant value
(``onp.ones(8)``) for the second argument. As a result, the sub-expression
(``np.ones(8)``) for the second argument. As a result, the sub-expression
``jnp.sin(second) * 3.`` is constant-folded.
There are two ConstVars, ``b`` (standing for ``jnp.sin(second) * 3.``) and ``d``
(standing for ``jnp.ones(8)``). Unfortunately, it is not easy to tell from the
@ -348,7 +348,7 @@ and :py:func:`jax.lax.fori_loop`
In the above signature, “C” stands for the type of the loop “carry” value.
For example, here is a fori loop
>>> import numpy as onp
>>> import numpy as np
>>>
>>> def func10(arg, n):
... ones = jnp.ones(arg.shape) # A constant
@ -356,7 +356,7 @@ For example, here is an example fori loop
... lambda i, carry: carry + ones * 3. + arg,
... arg + ones)
...
>>> print(make_jaxpr(func10)(onp.ones(16), 5))
>>> print(make_jaxpr(func10)(np.ones(16), 5))
{ lambda c d ; a b.
let e = add a d
_ _ f = while[ body_jaxpr={ lambda ; e g a b c.
@ -414,7 +414,7 @@ For the example consider the function ``func11`` below
... return (carry + ae1 * ae2 + extra, carry)
... return lax.scan(body, 0., (arr, ones))
...
>>> print(make_jaxpr(func11)(onp.ones(16), 5.))
>>> print(make_jaxpr(func11)(np.ones(16), 5.))
{ lambda c ; a b.
let d e = scan[ jaxpr={ lambda ; f a b c.
let d = mul b c

View File

@ -35,12 +35,12 @@
},
"outputs": [],
"source": [
"import numpy as onp\n",
"import numpy as np\n",
"from jax import grad, jit\n",
"from jax import lax\n",
"from jax import random\n",
"import jax\n",
"import jax.numpy as np\n",
"import jax.numpy as jnp\n",
"import matplotlib as mpl\n",
"from matplotlib import pyplot as plt\n",
"from matplotlib import rcParams\n",
@ -117,7 +117,7 @@
"print (\"Second call: \", jit(impure_print_side_effect)(5.))\n",
"\n",
"# JAX re-runs the Python function when the type or shape of the argument changes\n",
"print (\"Third call, different type: \", jit(impure_print_side_effect)(np.array([5.])))"
"print (\"Third call, different type: \", jit(impure_print_side_effect)(jnp.array([5.])))"
]
},
{
@ -157,7 +157,7 @@
"\n",
"# JAX re-runs the Python function when the type or shape of the argument changes\n",
"# This will end up reading the latest value of the global\n",
"print (\"Third call, different type: \", jit(impure_uses_globals)(np.array([4.])))"
"print (\"Third call, different type: \", jit(impure_uses_globals)(jnp.array([4.])))"
]
},
{
@ -334,7 +334,7 @@
}
],
"source": [
"numpy_array = onp.zeros((3,3), dtype=np.float32)\n",
"numpy_array = np.zeros((3,3), dtype=np.float32)\n",
"print(\"original array:\")\n",
"print(numpy_array)\n",
"\n",
@ -379,7 +379,7 @@
}
],
"source": [
"jax_array = np.zeros((3,3), dtype=np.float32)\n",
"jax_array = jnp.zeros((3,3), dtype=jnp.float32)\n",
"\n",
"# In place update of JAX's array will yield an error!\n",
"try:\n",
@ -470,7 +470,7 @@
}
],
"source": [
"jax_array = np.zeros((3, 3))\n",
"jax_array = jnp.zeros((3, 3))\n",
"print(\"original array:\")\n",
"print(jax_array)\n",
"\n",
@ -537,7 +537,7 @@
],
"source": [
"print(\"original array:\")\n",
"jax_array = np.ones((5, 6))\n",
"jax_array = jnp.ones((5, 6))\n",
"print(jax_array)\n",
"\n",
"new_jax_array = index_add(jax_array, index[::2, 3:], 7.)\n",
@ -591,7 +591,7 @@
],
"source": [
"try:\n",
" onp.arange(10)[11]\n",
" np.arange(10)[11]\n",
"except Exception as e:\n",
" print(\"Exception {}\".format(e))"
]
@ -633,14 +633,14 @@
}
],
"source": [
"np.arange(10)[11]"
"jnp.arange(10)[11]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that due to this behavior np.nanargmin and np.nanargmax return -1 for slices consisting of NaNs whereas Numpy would throw an error."
"Note that due to this behavior jnp.nanargmin and jnp.nanargmax return -1 for slices consisting of NaNs whereas Numpy would throw an error."
]
},
{
@ -700,9 +700,9 @@
}
],
"source": [
"print(onp.random.random())\n",
"print(onp.random.random())\n",
"print(onp.random.random())"
"print(np.random.random())\n",
"print(np.random.random())\n",
"print(np.random.random())"
]
},
{
@ -725,8 +725,8 @@
},
"outputs": [],
"source": [
"onp.random.seed(0)\n",
"rng_state = onp.random.get_state()\n",
"np.random.seed(0)\n",
"rng_state = np.random.get_state()\n",
"#print(rng_state)\n",
"# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,\n",
"# 2481403966, 4042607538, 337614300, ... 614 more numbers..., \n",
@ -753,23 +753,23 @@
},
"outputs": [],
"source": [
"_ = onp.random.uniform()\n",
"rng_state = onp.random.get_state()\n",
"_ = np.random.uniform()\n",
"rng_state = np.random.get_state()\n",
"#print(rng_state) \n",
"# --> ('MT19937', array([2443250962, 1093594115, 1878467924,\n",
"# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)\n",
"\n",
"# Let's exhaust the entropy in this PRNG statevector\n",
"for i in range(311):\n",
" _ = onp.random.uniform()\n",
"rng_state = onp.random.get_state()\n",
" _ = np.random.uniform()\n",
"rng_state = np.random.get_state()\n",
"#print(rng_state) \n",
"# --> ('MT19937', array([2443250962, 1093594115, 1878467924,\n",
"# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)\n",
"\n",
"# Next call iterates the RNG state for a new batch of fake \"entropy\".\n",
"_ = onp.random.uniform()\n",
"rng_state = onp.random.get_state()\n",
"_ = np.random.uniform()\n",
"rng_state = np.random.get_state()\n",
"# print(rng_state) \n",
"# --> ('MT19937', array([1499117434, 2949980591, 2242547484, \n",
"# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)"
@ -1146,7 +1146,7 @@
" y = y + x[i]\n",
" return y\n",
"\n",
"print(g(np.array([1., 2., 3.])))"
"print(g(jnp.array([1., 2., 3.])))"
]
},
{
@ -1206,13 +1206,13 @@
"\n",
"When we `jit`-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don't have to re-compile on each function evaluation.\n",
"\n",
"For example, if we evaluate an `@jit` function on the array `np.array([1., 2., 3.], np.float32)`, we might want to compile code that we can reuse to evaluate the function on `np.array([4., 5., 6.], np.float32)` to save on compile time.\n",
"For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time.\n",
"\n",
"To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/google/jax/blob/master/jax/abstract_arrays.py), and different transformations use different abstraction levels.\n",
"\n",
"By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), np.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.\n",
"By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.\n",
"\n",
"But there's a tradeoff here: if we trace a Python function on a `ShapedArray((), np.float32)` that isn't committed to a specific concrete value, when we hit a line like `if x < 3`, the expression `x < 3` evaluates to an abstract `ShapedArray((), np.bool_)` that represents the set `{True, False}`. When Python attempts to coerce that to a concrete `True` or `False`, we get an error: we don't know which branch to take, and can't continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace.\n",
"But there's a tradeoff here: if we trace a Python function on a `ShapedArray((), jnp.float32)` that isn't committed to a specific concrete value, when we hit a line like `if x < 3`, the expression `x < 3` evaluates to an abstract `ShapedArray((), jnp.bool_)` that represents the set `{True, False}`. When Python attempts to coerce that to a concrete `True` or `False`, we get an error: we don't know which branch to take, and can't continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace.\n",
"\n",
"The good news is that you can control this tradeoff yourself. By having `jit` trace on more refined abstract values, you can relax the traceability constraints. For example, using the `static_argnums` argument to `jit`, we can specify to trace on concrete values of some arguments. Here's that example function again:"
]
@ -1295,7 +1295,7 @@
"\n",
"f = jit(f, static_argnums=(1,))\n",
"\n",
"f(np.array([2., 3., 4.]), 2)"
"f(jnp.array([2., 3., 4.]), 2)"
]
},
{
@ -1347,7 +1347,7 @@
],
"source": [
"def example_fun(length, val):\n",
" return np.ones((length,)) * val\n",
" return jnp.ones((length,)) * val\n",
"# un-jit'd works fine\n",
"print(example_fun(5, 4))\n",
"\n",
@ -1487,7 +1487,7 @@
"source": [
"from jax import lax\n",
"\n",
"operand = np.array([0.])\n",
"operand = jnp.array([0.])\n",
"lax.cond(True, operand, lambda x: x+1, operand, lambda x: x-1)\n",
"# --> array([1.], dtype=float32)\n",
"lax.cond(False, operand, lambda x: x+1, operand, lambda x: x-1)\n",
@ -1690,10 +1690,10 @@
],
"source": [
"# 2D kernel - HWIO layout\n",
"kernel = onp.zeros((3, 3, 3, 3), dtype=np.float32)\n",
"kernel += onp.array([[1, 1, 0],\n",
"kernel = np.zeros((3, 3, 3, 3), dtype=jnp.float32)\n",
"kernel += np.array([[1, 1, 0],\n",
" [1, 0,-1],\n",
" [0,-1,-1]])[:, :, onp.newaxis, onp.newaxis]\n",
" [0,-1,-1]])[:, :, np.newaxis, np.newaxis]\n",
"\n",
"print(\"Edge Conv kernel:\")\n",
"plt.imshow(kernel[:, :, 0, 0]);"
@ -1744,7 +1744,7 @@
],
"source": [
"# NHWC layout\n",
"img = onp.zeros((1, 200, 198, 3), dtype=np.float32)\n",
"img = np.zeros((1, 200, 198, 3), dtype=jnp.float32)\n",
"for k in range(3):\n",
" x = 30 + 60*k\n",
" y = 20 + 60*k\n",
@ -1811,14 +1811,14 @@
}
],
"source": [
"out = lax.conv(np.transpose(img,[0,3,1,2]), # lhs = NCHW image tensor\n",
" np.transpose(kernel,[3,2,0,1]), # rhs = OIHW conv kernel tensor\n",
"out = lax.conv(jnp.transpose(img,[0,3,1,2]), # lhs = NCHW image tensor\n",
" jnp.transpose(kernel,[3,2,0,1]), # rhs = OIHW conv kernel tensor\n",
" (1, 1), # window strides\n",
" 'SAME') # padding mode\n",
"print(\"out shape: \", out.shape)\n",
"print(\"First output channel:\")\n",
"plt.figure(figsize=(10,10))\n",
"plt.imshow(onp.array(out)[0,0,:,:]);"
"plt.imshow(np.array(out)[0,0,:,:]);"
]
},
{
@ -1857,8 +1857,8 @@
],
"source": [
"out = lax.conv_with_general_padding(\n",
" np.transpose(img,[0,3,1,2]), # lhs = NCHW image tensor\n",
" np.transpose(kernel,[2,3,0,1]), # rhs = IOHW conv kernel tensor\n",
" jnp.transpose(img,[0,3,1,2]), # lhs = NCHW image tensor\n",
" jnp.transpose(kernel,[2,3,0,1]), # rhs = IOHW conv kernel tensor\n",
" (1, 1), # window strides\n",
" ((2,2),(2,2)), # general padding 2x2\n",
" (1,1), # lhs/image dilation\n",
@ -1866,7 +1866,7 @@
"print(\"out shape: \", out.shape)\n",
"print(\"First output channel:\")\n",
"plt.figure(figsize=(10,10))\n",
"plt.imshow(onp.array(out)[0,0,:,:]);"
"plt.imshow(np.array(out)[0,0,:,:]);"
]
},
{
@ -1973,7 +1973,7 @@
"print(\"out shape: \", out.shape)\n",
"print(\"First output channel:\")\n",
"plt.figure(figsize=(10,10))\n",
"plt.imshow(onp.array(out)[0,:,:,0]);"
"plt.imshow(np.array(out)[0,:,:,0]);"
]
},
{
@ -2031,7 +2031,7 @@
"print(\"out shape: \", out.shape, \"DIFFERENT from above!\")\n",
"print(\"First output channel:\")\n",
"plt.figure(figsize=(10,10))\n",
"plt.imshow(onp.array(out)[0,:,:,0]);"
"plt.imshow(np.array(out)[0,:,:,0]);"
]
},
{
@ -2089,7 +2089,7 @@
"print(\"out shape: \", out.shape, \" <-- half the size of above\")\n",
"plt.figure(figsize=(10,10))\n",
"print(\"First output channel:\")\n",
"plt.imshow(onp.array(out)[0,:,:,0]);"
"plt.imshow(np.array(out)[0,:,:,0]);"
]
},
{
@ -2147,7 +2147,7 @@
"print(\"out shape: \", out.shape)\n",
"plt.figure(figsize=(10,10))\n",
"print(\"First output channel:\")\n",
"plt.imshow(onp.array(out)[0,:,:,0]);"
"plt.imshow(np.array(out)[0,:,:,0]);"
]
},
{
@ -2205,7 +2205,7 @@
"print(\"out shape: \", out.shape, \"<-- larger than original!\")\n",
"plt.figure(figsize=(10,10))\n",
"print(\"First output channel:\")\n",
"plt.imshow(onp.array(out)[0,:,:,0]);"
"plt.imshow(np.array(out)[0,:,:,0]);"
]
},
{
@ -2259,7 +2259,7 @@
"\n",
"# transposed conv = 180deg kernel roation plus LHS dilation\n",
"# rotate kernel 180deg:\n",
"kernel_rot = np.rot90(np.rot90(kernel, axes=(0,1)), axes=(0,1))\n",
"kernel_rot = jnp.rot90(jnp.rot90(kernel, axes=(0,1)), axes=(0,1))\n",
"# need a custom output padding:\n",
"padding = ((2, 1), (2, 1))\n",
"out = lax.conv_general_dilated(img, # lhs = image tensor\n",
@ -2272,7 +2272,7 @@
"print(\"out shape: \", out.shape, \"<-- transposed_conv\")\n",
"plt.figure(figsize=(10,10))\n",
"print(\"First output channel:\")\n",
"plt.imshow(onp.array(out)[0,:,:,0]);"
"plt.imshow(np.array(out)[0,:,:,0]);"
]
},
{
@ -2344,11 +2344,11 @@
],
"source": [
"# 1D kernel - WIO layout\n",
"kernel = onp.array([[[1, 0, -1], [-1, 0, 1]], \n",
"kernel = np.array([[[1, 0, -1], [-1, 0, 1]], \n",
" [[1, 1, 1], [-1, -1, -1]]], \n",
" dtype=np.float32).transpose([2,1,0])\n",
" dtype=jnp.float32).transpose([2,1,0])\n",
"# 1D data - NWC layout\n",
"data = onp.zeros((1, 200, 2), dtype=np.float32)\n",
"data = np.zeros((1, 200, 2), dtype=jnp.float32)\n",
"for i in range(2):\n",
" for k in range(2):\n",
" x = 35*i + 30 + 60*k\n",
@ -2433,16 +2433,16 @@
],
"source": [
"# Random 3D kernel - HWDIO layout\n",
"kernel = onp.array([\n",
"kernel = np.array([\n",
" [[0, 0, 0], [0, 1, 0], [0, 0, 0]],\n",
" [[0, -1, 0], [-1, 0, -1], [0, -1, 0]], \n",
" [[0, 0, 0], [0, 1, 0], [0, 0, 0]]], \n",
" dtype=np.float32)[:, :, :, onp.newaxis, onp.newaxis]\n",
" dtype=jnp.float32)[:, :, :, np.newaxis, np.newaxis]\n",
"\n",
"# 3D data - NHWDC layout\n",
"data = onp.zeros((1, 30, 30, 30, 1), dtype=np.float32)\n",
"x, y, z = onp.mgrid[0:1:30j, 0:1:30j, 0:1:30j]\n",
"data += (onp.sin(2*x*np.pi)*onp.cos(2*y*np.pi)*onp.cos(2*z*np.pi))[None,:,:,:,None]\n",
"data = np.zeros((1, 30, 30, 30, 1), dtype=jnp.float32)\n",
"x, y, z = np.mgrid[0:1:30j, 0:1:30j, 0:1:30j]\n",
"data += (np.sin(2*x*jnp.pi)*np.cos(2*y*jnp.pi)*np.cos(2*z*jnp.pi))[None,:,:,:,None]\n",
"\n",
"print(\"in shapes:\", data.shape, kernel.shape)\n",
"dn = lax.conv_dimension_numbers(data.shape, kernel.shape,\n",
@ -2461,8 +2461,8 @@
"# Make some simple 3d density plots:\n",
"from mpl_toolkits.mplot3d import Axes3D\n",
"def make_alpha(cmap):\n",
" my_cmap = cmap(np.arange(cmap.N))\n",
" my_cmap[:,-1] = np.linspace(0, 1, cmap.N)**3\n",
" my_cmap = cmap(jnp.arange(cmap.N))\n",
" my_cmap[:,-1] = jnp.linspace(0, 1, cmap.N)**3\n",
" return mpl.colors.ListedColormap(my_cmap)\n",
"my_cmap = make_alpha(plt.cm.viridis)\n",
"fig = plt.figure()\n",
@ -2516,13 +2516,13 @@
"metadata": {},
"source": [
"```\n",
"In [1]: import jax.numpy as np\n",
"In [1]: import jax.numpy as jnp\n",
"\n",
"In [2]: np.divide(0., 0.)\n",
"In [2]: jnp.divide(0., 0.)\n",
"---------------------------------------------------------------------------\n",
"FloatingPointError Traceback (most recent call last)\n",
"<ipython-input-2-f2e2c413b437> in <module>()\n",
"----> 1 np.divide(0., 0.)\n",
"----> 1 jnp.divide(0., 0.)\n",
"\n",
".../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)\n",
" 343 return floor_divide(x1, x2)\n",
@ -2549,7 +2549,7 @@
"\n",
".../jax/jax/interpreters/xla.pyc in handle_result(device_buffer)\n",
" 103 py_val = device_buffer.to_py()\n",
" 104 if onp.any(onp.isnan(py_val)):\n",
" 104 if np.any(np.isnan(py_val)):\n",
"--> 105 raise FloatingPointError(\"invalid value\")\n",
" 106 else:\n",
" 107 return DeviceArray(device_buffer, *result_shape)\n",
@ -2580,9 +2580,9 @@
" ...: return a + b * c\n",
" ...:\n",
"\n",
"In [6]: x = np.array([2., 0.])\n",
"In [6]: x = jnp.array([2., 0.])\n",
"\n",
"In [7]: y = np.array([3., 0.])\n",
"In [7]: y = jnp.array([3., 0.])\n",
"\n",
"In [8]: f(x, y)\n",
"Invalid value encountered in the output of a jit function. Calling the de-optimized version.\n",
@ -2673,7 +2673,7 @@
}
],
"source": [
"x = random.uniform(random.PRNGKey(0), (1000,), dtype=np.float64)\n",
"x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)\n",
"x.dtype"
]
},
@ -2746,8 +2746,9 @@
}
],
"source": [
"from jax import numpy as np, random\n",
"x = random.uniform(random.PRNGKey(0), (1000,), dtype=np.float64)\n",
"import jax.numpy as jnp\n",
"from jax import random\n",
"x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)\n",
"x.dtype # --> dtype('float64')"
]
},
@ -2803,4 +2804,4 @@
},
"nbformat": 4,
"nbformat_minor": 1
}
}

View File

@ -64,19 +64,19 @@
"colab": {}
},
"source": [
"import jax.numpy as np\n",
"import jax.numpy as jnp\n",
"from jax import custom_jvp\n",
"\n",
"@custom_jvp\n",
"def f(x, y):\n",
" return np.sin(x) * y\n",
" return jnp.sin(x) * y\n",
"\n",
"@f.defjvp\n",
"def f_jvp(primals, tangents):\n",
" x, y = primals\n",
" x_dot, y_dot = tangents\n",
" primal_out = f(x, y)\n",
" tangent_out = np.cos(x) * x_dot * y + np.sin(x) * y_dot\n",
" tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot\n",
" return primal_out, tangent_out"
],
"execution_count": 0,
@ -128,10 +128,10 @@
"\n",
"@custom_jvp\n",
"def f(x, y):\n",
" return np.sin(x) * y\n",
" return jnp.sin(x) * y\n",
"\n",
"f.defjvps(lambda x_dot, primal_out, x, y: np.cos(x) * x_dot * y,\n",
" lambda y_dot, primal_out, x, y: np.sin(x) * y_dot)"
"f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,\n",
" lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)"
],
"execution_count": 0,
"outputs": []
@ -190,10 +190,10 @@
"\n",
"@custom_vjp\n",
"def f(x, y):\n",
" return np.sin(x) * y\n",
" return jnp.sin(x) * y\n",
"\n",
"def f_fwd(x, y):\n",
" return f(x, y), (np.cos(x), np.sin(x), y)\n",
" return f(x, y), (jnp.cos(x), jnp.sin(x), y)\n",
"\n",
"def f_bwd(res, g):\n",
" cos_x, sin_x, y = res\n",
@ -279,10 +279,10 @@
}
},
"source": [
"import jax.numpy as np\n",
"import jax.numpy as jnp\n",
"\n",
"def log1pexp(x):\n",
" return np.log(1. + np.exp(x))\n",
" return jnp.log(1. + jnp.exp(x))\n",
"\n",
"log1pexp(3.)"
],
@ -328,7 +328,7 @@
"\n",
"print(jit(log1pexp)(3.))\n",
"print(jit(grad(log1pexp))(3.))\n",
"print(vmap(jit(grad(log1pexp)))(np.arange(3.)))"
"print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))"
],
"execution_count": 8,
"outputs": [
@ -434,7 +434,7 @@
"colab_type": "text"
},
"source": [
"Stepping through how the jaxpr would be evaluated, we can see that the last line would involve multiplying values that floating point math will round to 0 and $\\infty$, respectively, which is never a good idea. That is, we're effectively evaluating `lambda x: (1 / (1 + np.exp(x))) * np.exp(x)` for large `x`, which effectively turns into `0. * np.inf`.\n",
"Stepping through how the jaxpr would be evaluated, we can see that the last line would involve multiplying values that floating point math will round to 0 and $\\infty$, respectively, which is never a good idea. That is, we're effectively evaluating `lambda x: (1 / (1 + jnp.exp(x))) * jnp.exp(x)` for large `x`, which effectively turns into `0. * jnp.inf`.\n",
"\n",
"Instead of generating such large and small values, hoping for a cancellation that floats can't always provide, we'd rather just express the derivative function as a more numerically stable program. In particular, we can write a program that more closely evaluates the equal mathematical expression $1 - \\frac{1}{1 + e^x}$, with no cancellation in sight.\n",
"\n",
@ -457,14 +457,14 @@
"\n",
"@custom_jvp\n",
"def log1pexp(x):\n",
" return np.log(1. + np.exp(x))\n",
" return jnp.log(1. + jnp.exp(x))\n",
"\n",
"@log1pexp.defjvp\n",
"def log1pexp_jvp(primals, tangents):\n",
" x, = primals\n",
" x_dot, = tangents\n",
" ans = log1pexp(x)\n",
" ans_dot = (1 - 1/(1 + np.exp(x))) * x_dot\n",
" ans_dot = (1 - 1/(1 + jnp.exp(x))) * x_dot\n",
" return ans, ans_dot"
],
"execution_count": 0,
@ -509,7 +509,7 @@
"source": [
"print(jit(log1pexp)(3.))\n",
"print(jit(grad(log1pexp))(3.))\n",
"print(vmap(jit(grad(log1pexp)))(np.arange(3.)))"
"print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))"
],
"execution_count": 13,
"outputs": [
@ -544,9 +544,9 @@
"source": [
"@custom_jvp\n",
"def log1pexp(x):\n",
" return np.log(1. + np.exp(x))\n",
" return jnp.log(1. + jnp.exp(x))\n",
"\n",
"log1pexp.defjvps(lambda t, ans, x: (1 - 1/(1 + np.exp(x))) * t)"
"log1pexp.defjvps(lambda t, ans, x: (1 - 1/(1 + jnp.exp(x))) * t)"
],
"execution_count": 0,
"outputs": []
@ -566,7 +566,7 @@
"print(grad(log1pexp)(100.))\n",
"print(jit(log1pexp)(3.))\n",
"print(jit(grad(log1pexp))(3.))\n",
"print(vmap(jit(grad(log1pexp)))(np.arange(3.)))"
"print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))"
],
"execution_count": 15,
"outputs": [
@ -614,7 +614,7 @@
},
"source": [
"def f(x):\n",
" return x / (1 + np.sqrt(x))"
" return x / (1 + jnp.sqrt(x))"
],
"execution_count": 0,
"outputs": []
@ -676,14 +676,14 @@
"source": [
"@custom_jvp\n",
"def f(x):\n",
" return x / (1 + np.sqrt(x))\n",
" return x / (1 + jnp.sqrt(x))\n",
"\n",
"@f.defjvp\n",
"def f_jvp(primals, tangents):\n",
" x, = primals\n",
" x_dot, = tangents\n",
" ans = f(x)\n",
" ans_dot = ((np.sqrt(x) + 2) / (2 * (np.sqrt(x) + 1)**2)) * x_dot\n",
" ans_dot = ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * x_dot\n",
" return ans, ans_dot"
],
"execution_count": 0,
@ -734,9 +734,9 @@
"source": [
"@custom_jvp\n",
"def f(x):\n",
" return x / (1 + np.sqrt(x))\n",
" return x / (1 + jnp.sqrt(x))\n",
"\n",
"f.defjvps(lambda t, ans, x: ((np.sqrt(x) + 2) / (2 * (np.sqrt(x) + 1)**2)) * t)"
"f.defjvps(lambda t, ans, x: ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * t)"
],
"execution_count": 0,
"outputs": []
@ -777,7 +777,7 @@
"\n",
"While in some cases we want to express a mathematical differentiation computation, in other cases we may even want to take a step away from mathematics to adjust the computation autodiff performs. One canonical example is reverse-mode gradient clipping.\n",
"\n",
"For gradient clipping, we can use `np.clip` together with a `jax.custom_vjp` reverse-mode-only rule:"
"For gradient clipping, we can use `jnp.clip` together with a `jax.custom_vjp` reverse-mode-only rule:"
]
},
{
@ -799,7 +799,7 @@
" return x, None # no residual values to save\n",
"\n",
"def clip_gradient_bwd(lo, hi, _, g):\n",
" return (np.clip(g, lo, hi),)\n",
" return (jnp.clip(g, lo, hi),)\n",
"\n",
"clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)"
],
@ -821,10 +821,10 @@
"import matplotlib.pyplot as plt\n",
"from jax import vmap\n",
"\n",
"t = np.linspace(0, 10, 1000)\n",
"t = jnp.linspace(0, 10, 1000)\n",
"\n",
"plt.plot(np.sin(t))\n",
"plt.plot(vmap(grad(np.sin))(t))"
"plt.plot(jnp.sin(t))\n",
"plt.plot(vmap(grad(jnp.sin))(t))"
],
"execution_count": 23,
"outputs": [
@ -868,7 +868,7 @@
"source": [
"def clip_sin(x):\n",
" x = clip_gradient(-0.75, 0.75, x)\n",
" return np.sin(x)\n",
" return jnp.sin(x)\n",
"\n",
"plt.plot(clip_sin(t))\n",
"plt.plot(vmap(grad(clip_sin))(t))"
@ -963,7 +963,7 @@
"def fixed_point(f, a, x_guess):\n",
" def cond_fun(carry):\n",
" x_prev, x = carry\n",
" return np.abs(x_prev - x) > 1e-6\n",
" return jnp.abs(x_prev - x) > 1e-6\n",
"\n",
" def body_fun(carry):\n",
" _, x = carry\n",
@ -1049,7 +1049,7 @@
}
},
"source": [
"print(jit(vmap(newton_sqrt))(np.array([1., 2., 3., 4.])))"
"print(jit(vmap(newton_sqrt))(jnp.array([1., 2., 3., 4.])))"
],
"execution_count": 28,
"outputs": [
@ -1108,7 +1108,7 @@
"def fixed_point(f, a, x_guess):\n",
" def cond_fun(carry):\n",
" x_prev, x = carry\n",
" return np.abs(x_prev - x) > 1e-6\n",
" return jnp.abs(x_prev - x) > 1e-6\n",
"\n",
" def body_fun(carry):\n",
" _, x = carry\n",
@ -1127,7 +1127,7 @@
" a_bar, = vjp_a(fixed_point(partial(rev_iter, f),\n",
" (a, x_star, x_star_bar),\n",
" x_star_bar))\n",
" return a_bar, np.zeros_like(x_star)\n",
" return a_bar, jnp.zeros_like(x_star)\n",
" \n",
"def rev_iter(f, packed, u):\n",
" a, x_star, x_star_bar = packed\n",
@ -1198,7 +1198,7 @@
"colab_type": "text"
},
"source": [
"We can check our answers by differentiating `np.sqrt`, which uses a totally different implementation:"
"We can check our answers by differentiating `jnp.sqrt`, which uses a totally different implementation:"
]
},
{
@ -1213,8 +1213,8 @@
}
},
"source": [
"print(grad(np.sqrt)(2.))\n",
"print(grad(grad(np.sqrt))(2.))"
"print(grad(jnp.sqrt)(2.))\n",
"print(grad(grad(jnp.sqrt))(2.))"
],
"execution_count": 32,
"outputs": [
@ -1270,18 +1270,18 @@
},
"source": [
"from jax import custom_jvp\n",
"import jax.numpy as np\n",
"import jax.numpy as jnp\n",
"\n",
"# f :: a -> b\n",
"@custom_jvp\n",
"def f(x):\n",
" return np.sin(x)\n",
" return jnp.sin(x)\n",
"\n",
"# f_jvp :: (a, T a) -> (b, T b)\n",
"def f_jvp(primals, tangents):\n",
" x, = primals\n",
" t, = tangents\n",
" return f(x), np.cos(x) * t\n",
" return f(x), jnp.cos(x) * t\n",
"\n",
"f.defjvp(f_jvp)"
],
@ -1479,9 +1479,9 @@
"source": [
"@custom_jvp\n",
"def f(x):\n",
" return np.sin(x)\n",
" return jnp.sin(x)\n",
"\n",
"f.defjvps(lambda t, ans, x: np.cos(x) * t)"
"f.defjvps(lambda t, ans, x: jnp.cos(x) * t)"
],
"execution_count": 0,
"outputs": []
@ -1656,14 +1656,14 @@
"@custom_jvp\n",
"def f(x):\n",
" print('called f!') # a harmless side-effect\n",
" return np.sin(x)\n",
" return jnp.sin(x)\n",
"\n",
"@f.defjvp\n",
"def f_jvp(primals, tangents):\n",
" print('called f_jvp!') # a harmless side-effect\n",
" x, = primals\n",
" t, = tangents\n",
" return f(x), np.cos(x) * t"
" return f(x), jnp.cos(x) * t"
],
"execution_count": 0,
"outputs": []
@ -1708,7 +1708,7 @@
}
},
"source": [
"print(vmap(f)(np.arange(3.)))\n",
"print(vmap(f)(jnp.arange(3.)))\n",
"print(jit(f)(3.))"
],
"execution_count": 46,
@ -1860,9 +1860,9 @@
"@custom_jvp\n",
"def f(x):\n",
" if x > 0:\n",
" return np.sin(x)\n",
" return jnp.sin(x)\n",
" else:\n",
" return np.cos(x)\n",
" return jnp.cos(x)\n",
"\n",
"@f.defjvp\n",
"def f_jvp(primals, tangents):\n",
@ -1925,16 +1925,16 @@
},
"source": [
"from jax import custom_vjp\n",
"import jax.numpy as np\n",
"import jax.numpy as jnp\n",
"\n",
"# f :: a -> b\n",
"@custom_vjp\n",
"def f(x):\n",
" return np.sin(x)\n",
" return jnp.sin(x)\n",
"\n",
"# f_fwd :: a -> (b, c)\n",
"def f_fwd(x):\n",
" return f(x), np.cos(x)\n",
" return f(x), jnp.cos(x)\n",
"\n",
"# f_bwd :: (c, CT b) -> CT a\n",
"def f_bwd(cos_x, y_bar):\n",
@ -2010,10 +2010,10 @@
"\n",
"@custom_vjp\n",
"def f(x, y):\n",
" return np.sin(x) * y\n",
" return jnp.sin(x) * y\n",
"\n",
"def f_fwd(x, y):\n",
" return f(x, y), (np.cos(x), np.sin(x), y)\n",
" return f(x, y), (jnp.cos(x), jnp.sin(x), y)\n",
"\n",
"def f_bwd(res, g):\n",
" cos_x, sin_x, y = res\n",
@ -2080,11 +2080,11 @@
"@custom_vjp\n",
"def f(x):\n",
" print(\"called f!\")\n",
" return np.sin(x)\n",
" return jnp.sin(x)\n",
"\n",
"def f_fwd(x):\n",
" print(\"called f_fwd!\")\n",
" return f(x), np.cos(x)\n",
" return f(x), jnp.cos(x)\n",
"\n",
"def f_bwd(cos_x, y_bar):\n",
" print(\"called f_bwd!\")\n",
@ -2304,7 +2304,7 @@
"def foo(x):\n",
" y = x ** 2\n",
" y = debug(y) # insert pdb in corresponding backward pass step\n",
" return np.sin(y)"
" return jnp.sin(y)"
],
"execution_count": 0,
"outputs": []
@ -2368,7 +2368,7 @@
"def f(pt):\n",
" x, y = pt.x, pt.y\n",
" return {'a': x ** 2,\n",
" 'b': (np.sin(x), np.cos(y))}\n",
" 'b': (jnp.sin(x), jnp.cos(y))}\n",
"\n",
"@f.defjvp\n",
"def f_jvp(primals, tangents):\n",
@ -2376,7 +2376,7 @@
" pt_dot, = tangents\n",
" ans = f(pt)\n",
" ans_dot = {'a': 2 * pt.x * pt_dot.x,\n",
" 'b': (np.cos(pt.x) * pt_dot.x, -np.sin(pt.y) * pt_dot.y)}\n",
" 'b': (jnp.cos(pt.x) * pt_dot.x, -jnp.sin(pt.y) * pt_dot.y)}\n",
" return ans, ans_dot\n",
"\n",
"def fun(pt):\n",
@ -2460,15 +2460,15 @@
"def f(pt):\n",
" x, y = pt.x, pt.y\n",
" return {'a': x ** 2,\n",
" 'b': (np.sin(x), np.cos(y))}\n",
" 'b': (jnp.sin(x), jnp.cos(y))}\n",
"\n",
"def f_fwd(pt):\n",
" return f(pt), pt\n",
"\n",
"def f_bwd(pt, g):\n",
" a_bar, (b0_bar, b1_bar) = g['a'], g['b']\n",
" x_bar = 2 * pt.x * a_bar + np.cos(pt.x) * b0_bar\n",
" y_bar = -np.sin(pt.y) * b1_bar\n",
" x_bar = 2 * pt.x * a_bar + jnp.cos(pt.x) * b0_bar\n",
" y_bar = -jnp.sin(pt.y) * b1_bar\n",
" return (Point(x_bar, y_bar),)\n",
"\n",
"f.defvjp(f_fwd, f_bwd)\n",

View File

@ -239,7 +239,7 @@
},
"source": [
"import jax.numpy as jnp\n",
"import numpy as onp\n",
"import numpy as np\n",
"\n",
"@trace(\"multiply_add_numpy\")\n",
"def multiply_add_numpy(x, y, z):\n",
@ -429,7 +429,7 @@
" the concrete result of the primitive.\n",
" \"\"\"\n",
" # Note that we can use the original numpy, which is not JAX traceable\n",
" return onp.add(onp.multiply(x, y), z)\n",
" return np.add(np.multiply(x, y), z)\n",
"\n",
"# Now we register the primal implementation with JAX\n",
"multiply_add_p.def_impl(multiply_add_impl)"
@ -1154,7 +1154,7 @@
" File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 340, in grad_f\n",
" _, g = value_and_grad_f(*args, **kwargs)\n",
" File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 398, in value_and_grad_f\n",
" g = vjp_py(onp.ones((), dtype=dtype))\n",
" g = vjp_py(np.ones((), dtype=dtype))\n",
"NotImplementedError: Reverse-mode differentiation rule for 'multiply_add' not implemented\n"
],
"name": "stderr"
@ -1473,8 +1473,8 @@
"source": [
"# The arguments are two vectors instead of two scalars\n",
"with expectNotImplementedError():\n",
" api.vmap(square_add_prim, in_axes=0, out_axes=0)(onp.array([2., 3.]),\n",
" onp.array([10., 20.]))"
" api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),\n",
" np.array([10., 20.]))"
],
"execution_count": 22,
"outputs": [
@ -1500,7 +1500,7 @@
"\n",
"Traceback (most recent call last):\n",
" File \"<ipython-input-22-70154d0e2ab6>\", line 3, in <module>\n",
" onp.array([10., 20.]))\n",
" np.array([10., 20.]))\n",
" File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 611, in batched_fun\n",
" lambda: _flatten_axes(out_tree(), out_axes))\n",
" File \"/usr/local/lib/python3.6/dist-packages/jax/interpreters/batching.py\", line 41, in batch\n",
@ -1574,9 +1574,9 @@
}
},
"source": [
"assert onp.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)(\n",
" onp.array([2., 3.]),\n",
" onp.array([10., 20.])),\n",
"assert np.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)(\n",
" np.array([2., 3.]),\n",
" np.array([10., 20.])),\n",
" [14., 29.])"
],
"execution_count": 24,
@ -1622,9 +1622,9 @@
}
},
"source": [
"assert onp.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0))\n",
" (onp.array([2., 3.]),\n",
" onp.array([10., 20.])),\n",
"assert np.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0))\n",
" (np.array([2., 3.]),\n",
" np.array([10., 20.])),\n",
" [14., 29.])"
],
"execution_count": 25,


@ -47,7 +47,7 @@
},
"outputs": [],
"source": [
"import jax.numpy as np\n",
"import jax.numpy as jnp\n",
"from jax import grad, jit, vmap\n",
"from jax import random"
]
@ -118,17 +118,17 @@
"from jax.scipy.special import logsumexp\n",
"\n",
"def relu(x):\n",
" return np.maximum(0, x)\n",
" return jnp.maximum(0, x)\n",
"\n",
"def predict(params, image):\n",
" # per-example predictions\n",
" activations = image\n",
" for w, b in params[:-1]:\n",
" outputs = np.dot(w, activations) + b\n",
" outputs = jnp.dot(w, activations) + b\n",
" activations = relu(outputs)\n",
" \n",
" final_w, final_b = params[-1]\n",
" logits = np.dot(final_w, activations) + final_b\n",
" logits = jnp.dot(final_w, activations) + final_b\n",
" return logits - logsumexp(logits)"
]
},
@ -262,18 +262,18 @@
},
"outputs": [],
"source": [
"def one_hot(x, k, dtype=np.float32):\n",
"def one_hot(x, k, dtype=jnp.float32):\n",
" \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n",
" return np.array(x[:, None] == np.arange(k), dtype)\n",
" return jnp.array(x[:, None] == jnp.arange(k), dtype)\n",
" \n",
"def accuracy(params, images, targets):\n",
" target_class = np.argmax(targets, axis=1)\n",
" predicted_class = np.argmax(batched_predict(params, images), axis=1)\n",
" return np.mean(predicted_class == target_class)\n",
" target_class = jnp.argmax(targets, axis=1)\n",
" predicted_class = jnp.argmax(batched_predict(params, images), axis=1)\n",
" return jnp.mean(predicted_class == target_class)\n",
"\n",
"def loss(params, images, targets):\n",
" preds = batched_predict(params, images)\n",
" return -np.mean(preds * targets)\n",
" return -jnp.mean(preds * targets)\n",
"\n",
"@jit\n",
"def update(params, x, y):\n",
@ -334,18 +334,18 @@
},
"outputs": [],
"source": [
"import numpy as onp\n",
"import numpy as np\n",
"from torch.utils import data\n",
"from torchvision.datasets import MNIST\n",
"\n",
"def numpy_collate(batch):\n",
" if isinstance(batch[0], onp.ndarray):\n",
" return onp.stack(batch)\n",
" if isinstance(batch[0], np.ndarray):\n",
" return np.stack(batch)\n",
" elif isinstance(batch[0], (tuple,list)):\n",
" transposed = zip(*batch)\n",
" return [numpy_collate(samples) for samples in transposed]\n",
" else:\n",
" return onp.array(batch)\n",
" return np.array(batch)\n",
"\n",
"class NumpyLoader(data.DataLoader):\n",
" def __init__(self, dataset, batch_size=1,\n",
@ -367,7 +367,7 @@
"\n",
"class FlattenAndCast(object):\n",
" def __call__(self, pic):\n",
" return onp.ravel(onp.array(pic, dtype=np.float32))"
" return np.ravel(np.array(pic, dtype=jnp.float32))"
]
},
{
@ -526,13 +526,13 @@
],
"source": [
"# Get the full train dataset (for checking accuracy while training)\n",
"train_images = onp.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)\n",
"train_labels = one_hot(onp.array(mnist_dataset.train_labels), n_targets)\n",
"train_images = np.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)\n",
"train_labels = one_hot(np.array(mnist_dataset.train_labels), n_targets)\n",
"\n",
"# Get full test dataset\n",
"mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)\n",
"test_images = np.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtype=np.float32)\n",
"test_labels = one_hot(onp.array(mnist_dataset_test.test_labels), n_targets)"
"test_images = jnp.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtype=jnp.float32)\n",
"test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets)"
]
},
{


@ -33,9 +33,9 @@
},
"outputs": [],
"source": [
"import numpy as onp\n",
"import numpy as np\n",
"import jax\n",
"import jax.numpy as np\n",
"import jax.numpy as jnp\n",
"from jax import jit, grad, vmap\n",
"from jax import random"
]
@ -84,7 +84,7 @@
"source": [
"x = random.normal(random.PRNGKey(0), (5000, 5000))\n",
"def f(w, b, x):\n",
" return np.tanh(np.dot(x, w) + b)\n",
" return jnp.tanh(jnp.dot(x, w) + b)\n",
"fast_f = jit(f)"
]
},
@ -194,10 +194,10 @@
"print()\n",
"\n",
"def bar(w, b, x):\n",
" return np.dot(w, x) + b + np.ones(5), x\n",
" return jnp.dot(w, x) + b + jnp.ones(5), x\n",
"print(\"bar\")\n",
"print(\"=====\")\n",
"examine_jaxpr(jax.make_jaxpr(bar)(np.ones((5, 10)), np.ones(5), np.ones(10)))"
"examine_jaxpr(jax.make_jaxpr(bar)(jnp.ones((5, 10)), jnp.ones(5), jnp.ones(10)))"
]
},
{
@ -257,9 +257,9 @@
"Goal:\n",
"```python\n",
"def f(x):\n",
" return np.exp(np.tanh(x))\n",
" return jnp.exp(jnp.tanh(x))\n",
"f_inv = inverse(f)\n",
"assert np.allclose(f_inv(f(1.0)), 1.0)\n",
"assert jnp.allclose(f_inv(f(1.0)), 1.0)\n",
"```\n",
"\n",
"The way we'll implement this is by (1) tracing `f` into a Jaxpr, then (2) interpreting the Jaxpr *backwards*. While interpreting the Jaxpr backwards, for each equation we'll look up the primitive's inverse in a table and apply it.\n",
@ -280,7 +280,7 @@
"outputs": [],
"source": [
"# Importing Jax functions useful for tracing/interpreting.\n",
"import numpy as onp\n",
"import numpy as np\n",
"from functools import wraps\n",
"\n",
"from jax import api_util\n",
@ -308,7 +308,7 @@
" def pv_like(x):\n",
" # ShapedArrays are abstract values that carry around\n",
" # shape and dtype information\n",
" aval = ShapedArray(onp.shape(x), onp.result_type(x))\n",
" aval = ShapedArray(np.shape(x), np.result_type(x))\n",
" return pe.PartialVal.unknown(aval)\n",
"\n",
" @wraps(fun)\n",
@ -366,8 +366,8 @@
],
"source": [
"def f(x):\n",
" return np.exp(np.tanh(x))\n",
"jaxpr, consts, _ = make_jaxpr2(f)(np.ones(5))\n",
" return jnp.exp(jnp.tanh(x))\n",
"jaxpr, consts, _ = make_jaxpr2(f)(jnp.ones(5))\n",
"print(jaxpr)\n",
"print(consts)"
]
@ -457,8 +457,8 @@
}
],
"source": [
"jaxpr, consts, _ = make_jaxpr2(f)(np.ones(5))\n",
"eval_jaxpr(jaxpr, consts, np.ones(5))"
"jaxpr, consts, _ = make_jaxpr2(f)(jnp.ones(5))\n",
"eval_jaxpr(jaxpr, consts, jnp.ones(5))"
]
},
{
@ -521,8 +521,8 @@
},
"outputs": [],
"source": [
"inverse_registry[lax.exp_p] = np.log\n",
"inverse_registry[lax.tanh_p] = np.arctanh"
"inverse_registry[lax.exp_p] = jnp.log\n",
"inverse_registry[lax.tanh_p] = jnp.arctanh"
]
},
{
@ -639,9 +639,9 @@
],
"source": [
"def f(x):\n",
" return np.exp(np.tanh(x))\n",
" return jnp.exp(jnp.tanh(x))\n",
"f_inv = inverse(f)\n",
"assert np.allclose(f_inv(f(1.0)), 1.0)"
"assert jnp.allclose(f_inv(f(1.0)), 1.0)"
]
},
{
@ -728,7 +728,7 @@
}
],
"source": [
"jit(vmap(grad(inverse(f))))((np.arange(5) + 1.) / 5.)"
"jit(vmap(grad(inverse(f))))((jnp.arange(5) + 1.) / 5.)"
]
},
{


@ -97,8 +97,7 @@
"colab": {}
},
"source": [
"# We import as onp to emphasize that we're using vanilla numpy, not jax numpy.\n",
"import numpy as onp\n",
"import numpy as np\n",
"\n",
"# We only need to import JAX's xla_client, not all of JAX.\n",
"from jaxlib import xla_client\n",
@ -138,12 +137,12 @@
"def canonicalize_dtype(dtype):\n",
" \"\"\"We restrict ourselves to 32bit types for this demo.\"\"\"\n",
" _dtype_to_32bit_dtype = {\n",
" str(onp.dtype('int64')): onp.dtype('int32'),\n",
" str(onp.dtype('uint64')): onp.dtype('uint32'),\n",
" str(onp.dtype('float64')): onp.dtype('float32'),\n",
" str(onp.dtype('complex128')): onp.dtype('complex64'),\n",
" str(np.dtype('int64')): np.dtype('int32'),\n",
" str(np.dtype('uint64')): np.dtype('uint32'),\n",
" str(np.dtype('float64')): np.dtype('float32'),\n",
" str(np.dtype('complex128')): np.dtype('complex64'),\n",
" }\n",
" dtype = onp.dtype(dtype)\n",
" dtype = np.dtype(dtype)\n",
" return _dtype_to_32bit_dtype.get(str(dtype), dtype)\n",
"\n",
"def shape_of(value):\n",
@ -151,8 +150,8 @@
" if hasattr(value, 'shape') and hasattr(value, 'dtype'):\n",
" return xla_client.Shape.array_shape(canonicalize_dtype(value.dtype), \n",
" value.shape)\n",
" elif onp.isscalar(value):\n",
" return shape_of(onp.asarray(value))\n",
" elif np.isscalar(value):\n",
" return shape_of(np.asarray(value))\n",
" elif isinstance(value, (tuple, list)):\n",
" return xla_client.Shape.tuple_shape(tuple(shape_of(elt) for elt in value))\n",
" else:\n",
@ -163,8 +162,8 @@
" if isinstance(dtype, str):\n",
" return xla_client.DTYPE_TO_XLA_ELEMENT_TYPE[dtype]\n",
" elif isinstance(dtype, type):\n",
" return xla_client.DTYPE_TO_XLA_ELEMENT_TYPE[onp.dtype(dtype).name]\n",
" elif isinstance(dtype, onp.dtype):\n",
" return xla_client.DTYPE_TO_XLA_ELEMENT_TYPE[np.dtype(dtype).name]\n",
" elif isinstance(dtype, np.dtype):\n",
" return xla_client.DTYPE_TO_XLA_ELEMENT_TYPE[dtype.name]\n",
" else:\n",
" raise TypeError('Unexpected type: {}'.format(type(dtype)))"
@ -197,7 +196,7 @@
"c = xla_client.ComputationBuilder(\"simple_scalar\")\n",
"\n",
"# define a parameter shape and parameter\n",
"param_shape = xla_client.Shape.array_shape(onp.dtype(onp.float32), ())\n",
"param_shape = xla_client.Shape.array_shape(np.dtype(np.float32), ())\n",
"x = c.ParameterWithShape(param_shape)\n",
"\n",
"# define computation graph\n",
@ -212,7 +211,7 @@
"compiled_computation = computation.Compile([param_shape,])\n",
"\n",
"# define a host variable with above parameter shape\n",
"host_input = onp.array(3.0, dtype=onp.float32)\n",
"host_input = np.array(3.0, dtype=np.float32)\n",
"\n",
"# place host variable on device and execute\n",
"device_input = xla_client.LocalBuffer.from_pyval(host_input)\n",
@ -251,15 +250,15 @@
"# same as above with vector type:\n",
"\n",
"c = xla_client.ComputationBuilder(\"simple_vector\")\n",
"param_shape = xla_client.Shape.array_shape(onp.dtype(onp.float32), (3,))\n",
"param_shape = xla_client.Shape.array_shape(np.dtype(np.float32), (3,))\n",
"x = c.ParameterWithShape(param_shape)\n",
"\n",
"# can also use this function to define a shape from an example:\n",
"#x = c.ParameterFromNumpy(onp.array([0.0, 0.0, 0.0], dtype=onp.float32))\n",
"#x = c.ParameterFromNumpy(np.array([0.0, 0.0, 0.0], dtype=np.float32))\n",
"\n",
"# which is the same as using our convenience function above:\n",
"#x = c.ParameterWithShape(shape_of(onp.array([0.0, 0.0, 0.0], \n",
"# dtype=onp.float32)))\n",
"#x = c.ParameterWithShape(shape_of(np.array([0.0, 0.0, 0.0], \n",
"# dtype=np.float32)))\n",
"\n",
"# chain steps by reference:\n",
"y = c.Sin(x)\n",
@ -267,7 +266,7 @@
"computation = c.Build()\n",
"compiled_computation = computation.Compile([param_shape,])\n",
"\n",
"host_input = onp.array([3.0, 4.0, 5.0], dtype=onp.float32)\n",
"host_input = np.array([3.0, 4.0, 5.0], dtype=np.float32)\n",
"\n",
"device_input = xla_client.LocalBuffer.from_pyval(host_input)\n",
"device_out = compiled_computation.Execute([device_input ,])\n",
@ -322,14 +321,14 @@
"# body computation:\n",
"bcb = xla_client.ComputationBuilder(\"bodycomp\")\n",
"x = bcb.ParameterWithShape(in_shape)\n",
"const1 = bcb.Constant(onp.int32(1))\n",
"const1 = bcb.Constant(np.int32(1))\n",
"y = bcb.Sub(x, const1)\n",
"body_computation = bcb.Build()\n",
"\n",
"# test computation:\n",
"tcb = xla_client.ComputationBuilder(\"testcomp\")\n",
"x = tcb.ParameterWithShape(in_shape)\n",
"const0 = tcb.Constant(onp.int32(0))\n",
"const0 = tcb.Constant(np.int32(0))\n",
"y = tcb.Gt(x, const0)\n",
"test_computation = tcb.Build()\n",
"\n",
@ -342,7 +341,7 @@
"# Now compile and execute:\n",
"compiled_computation = while_computation.Compile([in_shape,])\n",
"\n",
"host_input = onp.array(5, dtype=onp.int32)\n",
"host_input = np.array(5, dtype=np.int32)\n",
"\n",
"device_input = xla_client.LocalBuffer.from_pyval(host_input)\n",
"device_out = compiled_computation.Execute([device_input ,])\n",
@ -402,7 +401,7 @@
"x = bcb.GetTupleElement(intuple, 1)\n",
"guard_cntr = bcb.GetTupleElement(intuple, 2)\n",
"new_x = bcb.Sub(x, bcb.Div(bcb.Sub(bcb.Mul(x, x), y), bcb.Add(x, x)))\n",
"result = bcb.Tuple(y, new_x, bcb.Sub(guard_cntr, bcb.Constant(onp.int32(1))))\n",
"result = bcb.Tuple(y, new_x, bcb.Sub(guard_cntr, bcb.Constant(np.int32(1))))\n",
"body_computation = bcb.Build()\n",
"\n",
"# test computation -- convergence and max iteration test\n",
@ -413,8 +412,8 @@
"guard_cntr = tcb.GetTupleElement(intuple, 2)\n",
"criterion = tcb.Abs(tcb.Sub(tcb.Mul(x, x), y))\n",
"# stop at convergence criteria or too many iterations\n",
"test = tcb.And(tcb.Gt(criterion, tcb.Constant(onp.float32(converged_delta))), \n",
" tcb.Gt(guard_cntr, tcb.Constant(onp.int32(0))))\n",
"test = tcb.And(tcb.Gt(criterion, tcb.Constant(np.float32(converged_delta))), \n",
" tcb.Gt(guard_cntr, tcb.Constant(np.int32(0))))\n",
"test_computation = tcb.Build()\n",
"\n",
"# while computation:\n",
@ -426,9 +425,9 @@
"# Now compile and execute:\n",
"compiled_computation = while_computation.Compile([in_shape,])\n",
"\n",
"y = onp.array(Xsqr, dtype=onp.float32)\n",
"x = onp.array(guess, dtype=onp.float32)\n",
"maxit = onp.array(maxit, dtype=onp.int32)\n",
"y = np.array(Xsqr, dtype=np.float32)\n",
"x = np.array(guess, dtype=np.float32)\n",
"maxit = np.array(maxit, dtype=np.int32)\n",
"\n",
"device_input = xla_client.LocalBuffer.from_pyval((y, x, maxit))\n",
"device_out = compiled_computation.Execute([device_input ,])\n",
@ -483,12 +482,12 @@
"Niter = 200\n",
"matrix_shape = (10, 10)\n",
"in_shape = shape_of(\n",
" (onp.zeros(matrix_shape, dtype=onp.float32), 1)\n",
" (np.zeros(matrix_shape, dtype=np.float32), 1)\n",
")\n",
"# NB: in_shape is the same as the manually constructed:\n",
"# xla_client.Shape.tuple_shape(\n",
"# (xla_client.Shape.array_shape(onp.dtype(onp.float32), matrix_shape), \n",
"# xla_client.Shape.array_shape(onp.dtype(onp.int32), ()))\n",
"# (xla_client.Shape.array_shape(np.dtype(np.float32), matrix_shape), \n",
"# xla_client.Shape.array_shape(np.dtype(np.int32), ()))\n",
"# )\n",
"\n",
"# body computation -- QR loop: X_i = Q R , X_{i+1} = R Q\n",
@ -500,14 +499,14 @@
"Q = bcb.GetTupleElement(QR, 0)\n",
"R = bcb.GetTupleElement(QR, 1)\n",
"RQ = bcb.Dot(R, Q)\n",
"bcb.Tuple(RQ, bcb.Sub(cntr, bcb.Constant(onp.int32(1))))\n",
"bcb.Tuple(RQ, bcb.Sub(cntr, bcb.Constant(np.int32(1))))\n",
"body_computation = bcb.Build()\n",
"\n",
"# test computation -- just a for loop condition\n",
"tcb = xla_client.ComputationBuilder(\"testcomp\")\n",
"intuple = tcb.ParameterWithShape(in_shape)\n",
"cntr = tcb.GetTupleElement(intuple, 1)\n",
"test = tcb.Gt(cntr, tcb.Constant(onp.int32(0)))\n",
"test = tcb.Gt(cntr, tcb.Constant(np.int32(0)))\n",
"test_computation = tcb.Build()\n",
"\n",
"# while computation:\n",
@ -519,9 +518,9 @@
"# Now compile and execute:\n",
"compiled_computation = while_computation.Compile([in_shape,])\n",
"\n",
"X = onp.random.random(matrix_shape).astype(onp.float32)\n",
"X = np.random.random(matrix_shape).astype(np.float32)\n",
"X = (X + X.T) / 2.0\n",
"it = onp.array(Niter, dtype=onp.int32)\n",
"it = np.array(Niter, dtype=np.int32)\n",
"\n",
"device_in = xla_client.LocalBuffer.from_pyval((X, it))\n",
"device_out = compiled_computation.Execute([device_in,])\n",
@ -532,11 +531,11 @@
"plt.title('D')\n",
"plt.imshow(host_out[0])\n",
"print('sorted eigenvalues')\n",
"print(onp.sort(eigh_vals))\n",
"print(np.sort(eigh_vals))\n",
"print('sorted eigenvalues from numpy')\n",
"print(onp.sort(onp.linalg.eigh(X)[0]))\n",
"print(np.sort(np.linalg.eigh(X)[0]))\n",
"print('sorted error') \n",
"print(onp.sort(eigh_vals) - onp.sort(onp.linalg.eigh(X)[0]))"
"print(np.sort(eigh_vals) - np.sort(np.linalg.eigh(X)[0]))"
],
"execution_count": 0,
"outputs": [
@ -605,8 +604,8 @@
"Niter = 100\n",
"matrix_shape = (10, 10)\n",
"in_shape = shape_of(\n",
" (onp.zeros(matrix_shape, dtype=onp.float32), \n",
" onp.eye(matrix_shape[0]),\n",
" (np.zeros(matrix_shape, dtype=np.float32), \n",
" np.eye(matrix_shape[0]),\n",
" 1)\n",
")\n",
"\n",
@ -621,14 +620,14 @@
"R = bcb.GetTupleElement(QR, 1)\n",
"RQ = bcb.Dot(R, Q)\n",
"Onew = bcb.Dot(O, Q)\n",
"bcb.Tuple(RQ, Onew, bcb.Sub(cntr, bcb.Constant(onp.int32(1))))\n",
"bcb.Tuple(RQ, Onew, bcb.Sub(cntr, bcb.Constant(np.int32(1))))\n",
"body_computation = bcb.Build()\n",
"\n",
"# test computation -- just a for loop condition\n",
"tcb = xla_client.ComputationBuilder(\"testcomp\")\n",
"intuple = tcb.ParameterWithShape(in_shape)\n",
"cntr = tcb.GetTupleElement(intuple, 2)\n",
"test = tcb.Gt(cntr, tcb.Constant(onp.int32(0)))\n",
"test = tcb.Gt(cntr, tcb.Constant(np.int32(0)))\n",
"test_computation = tcb.Build()\n",
"\n",
"# while computation:\n",
@ -640,10 +639,10 @@
"# Now compile and execute:\n",
"compiled_computation = while_computation.Compile([in_shape,])\n",
"\n",
"X = onp.random.random(matrix_shape).astype(onp.float32)\n",
"X = np.random.random(matrix_shape).astype(np.float32)\n",
"X = (X + X.T) / 2.0\n",
"Omat = onp.eye(matrix_shape[0], dtype=onp.float32)\n",
"it = onp.array(Niter, dtype=onp.int32)\n",
"Omat = np.eye(matrix_shape[0], dtype=np.float32)\n",
"it = np.array(Niter, dtype=np.int32)\n",
"\n",
"device_in = xla_client.LocalBuffer.from_pyval((X, Omat, it))\n",
"device_out = compiled_computation.Execute([device_in,])\n",
@ -659,13 +658,13 @@
"plt.imshow(eigh_mat)\n",
"plt.figure()\n",
"plt.title('U^T A U')\n",
"plt.imshow(onp.dot(onp.dot(eigh_mat.T, X), eigh_mat))\n",
"plt.imshow(np.dot(np.dot(eigh_mat.T, X), eigh_mat))\n",
"print('sorted eigenvalues')\n",
"print(onp.sort(eigh_vals))\n",
"print(np.sort(eigh_vals))\n",
"print('sorted eigenvalues from numpy')\n",
"print(onp.sort(onp.linalg.eigh(X)[0]))\n",
"print(np.sort(np.linalg.eigh(X)[0]))\n",
"print('sorted error') \n",
"print(onp.sort(eigh_vals) - onp.sort(onp.linalg.eigh(X)[0]))"
"print(np.sort(eigh_vals) - np.sort(np.linalg.eigh(X)[0]))"
],
"execution_count": 0,
"outputs": [
@ -752,7 +751,7 @@
"Niter=13\n",
"matrix_shape = (1,1, 20, 20)\n",
"in_shape = shape_of(\n",
" (onp.zeros(matrix_shape, dtype=onp.int32), 1)\n",
" (np.zeros(matrix_shape, dtype=np.int32), 1)\n",
")\n",
"\n",
"# Body computation -- Conway Update\n",
@ -762,8 +761,8 @@
"cntr = bcb.GetTupleElement(intuple, 1)\n",
"# convs require floating-point type\n",
"xf = bcb.ConvertElementType(x, to_xla_type('float32'))\n",
"stamp = bcb.Constant(onp.ones((1,1,3,3), dtype=onp.float32))\n",
"convd = bcb.Conv(xf, stamp, onp.array([1, 1]), xla_client.PaddingType.SAME)\n",
"stamp = bcb.Constant(np.ones((1,1,3,3), dtype=np.float32))\n",
"convd = bcb.Conv(xf, stamp, np.array([1, 1]), xla_client.PaddingType.SAME)\n",
"# logic ops require integer types\n",
"convd = bcb.ConvertElementType(convd, to_xla_type('int32'))\n",
"bool_x = bcb.Eq(x, bcb.ConstantS32Scalar(1))\n",
@ -800,13 +799,13 @@
"compiled_computation = while_computation.Compile([in_shape,])\n",
"\n",
"# Set up initial state\n",
"X = onp.zeros(matrix_shape, dtype=onp.int32)\n",
"X[0,0, 5:8, 5:8] = onp.array([[0,1,0],[0,0,1],[1,1,1]])\n",
"X = np.zeros(matrix_shape, dtype=np.int32)\n",
"X[0,0, 5:8, 5:8] = np.array([[0,1,0],[0,0,1],[1,1,1]])\n",
"\n",
"# Evolve\n",
"movie = onp.zeros((Niter,)+matrix_shape[-2:], dtype=onp.int32)\n",
"movie = np.zeros((Niter,)+matrix_shape[-2:], dtype=np.int32)\n",
"for it in range(Niter):\n",
" itr = onp.array(it, dtype=onp.int32)\n",
" itr = np.array(it, dtype=np.int32)\n",
" device_in = xla_client.LocalBuffer.from_pyval((X, itr))\n",
" device_out = compiled_computation.Execute([device_in,])\n",
" movie[it] = device_out.to_py()[0][0,0]\n",


@ -52,7 +52,7 @@
}
},
"source": [
"import jax.numpy as np\n",
"import jax.numpy as jnp\n",
"from jax import grad, jit, vmap\n",
"from jax import random\n",
"\n",
@ -104,7 +104,7 @@
}
},
"source": [
"grad_tanh = grad(np.tanh)\n",
"grad_tanh = grad(jnp.tanh)\n",
"print(grad_tanh(2.0))"
],
"execution_count": 2,
@ -142,8 +142,8 @@
}
},
"source": [
"print(grad(grad(np.tanh))(2.0))\n",
"print(grad(grad(grad(np.tanh)))(2.0))"
"print(grad(grad(jnp.tanh))(2.0))\n",
"print(grad(grad(grad(jnp.tanh)))(2.0))"
],
"execution_count": 3,
"outputs": [
@ -176,24 +176,24 @@
},
"source": [
"def sigmoid(x):\n",
" return 0.5 * (np.tanh(x / 2) + 1)\n",
" return 0.5 * (jnp.tanh(x / 2) + 1)\n",
"\n",
"# Outputs probability of a label being true.\n",
"def predict(W, b, inputs):\n",
" return sigmoid(np.dot(inputs, W) + b)\n",
" return sigmoid(jnp.dot(inputs, W) + b)\n",
"\n",
"# Build a toy dataset.\n",
"inputs = np.array([[0.52, 1.12, 0.77],\n",
"inputs = jnp.array([[0.52, 1.12, 0.77],\n",
" [0.88, -1.08, 0.15],\n",
" [0.52, 0.06, -1.30],\n",
" [0.74, -2.49, 1.39]])\n",
"targets = np.array([True, True, False, True])\n",
"targets = jnp.array([True, True, False, True])\n",
"\n",
"# Training loss is the negative log-likelihood of the training examples.\n",
"def loss(W, b):\n",
" preds = predict(W, b, inputs)\n",
" label_probs = preds * targets + (1 - preds) * (1 - targets)\n",
" return -np.sum(np.log(label_probs))\n",
" return -jnp.sum(jnp.log(label_probs))\n",
"\n",
"# Initialize random model coefficients\n",
"key, W_key, b_key = random.split(key, 3)\n",
@ -304,7 +304,7 @@
"def loss2(params_dict):\n",
" preds = predict(params_dict['W'], params_dict['b'], inputs)\n",
" label_probs = preds * targets + (1 - preds) * (1 - targets)\n",
" return -np.sum(np.log(label_probs))\n",
" return -jnp.sum(jnp.log(label_probs))\n",
"\n",
"print(grad(loss2)({'W': W, 'b': b}))"
],
@ -413,10 +413,10 @@
"# Check W_grad with finite differences in a random direction\n",
"key, subkey = random.split(key)\n",
"vec = random.normal(subkey, W.shape)\n",
"unitvec = vec / np.sqrt(np.vdot(vec, vec))\n",
"unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))\n",
"W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps\n",
"print('W_dirderiv_numerical', W_grad_numerical)\n",
"print('W_dirderiv_autodiff', np.vdot(grad(loss)(W, b), unitvec))"
"print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec))"
],
"execution_count": 8,
"outputs": [
@ -508,7 +508,7 @@
},
"source": [
"def hvp(f, x, v):\n",
" return grad(lambda x: np.vdot(grad(f)(x), v))(x)"
" return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)"
],
"execution_count": 0,
"outputs": []
@ -945,7 +945,7 @@
},
"source": [
"def hvp(f, x, v):\n",
" return grad(lambda x: np.vdot(grad(f)(x), v))(x)"
" return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)"
],
"execution_count": 0,
"outputs": []
@ -994,7 +994,7 @@
"id": "XUsye1SwfSFm"
},
"source": [
"Even better, since we didn't have to call `np.dot` directly, this `hvp` function works with arrays of any shape and with arbitrary container types (like vectors stored as nested lists/dicts/tuples), and doesn't even have a dependence on `jax.numpy`.\n",
"Even better, since we didn't have to call `jnp.dot` directly, this `hvp` function works with arrays of any shape and with arbitrary container types (like vectors stored as nested lists/dicts/tuples), and doesn't even have a dependence on `jax.numpy`.\n",
"\n",
"Here's an example of how to use it:"
]
@ -1012,16 +1012,16 @@
},
"source": [
"def f(X):\n",
" return np.sum(np.tanh(X)**2)\n",
" return jnp.sum(jnp.tanh(X)**2)\n",
"\n",
"key, subkey1, subkey2 = random.split(key, 3)\n",
"X = random.normal(subkey1, (30, 40))\n",
"V = random.normal(subkey2, (30, 40))\n",
"\n",
"ans1 = hvp(f, (X,), (V,))\n",
"ans2 = np.tensordot(hessian(f)(X), V, 2)\n",
"ans2 = jnp.tensordot(hessian(f)(X), V, 2)\n",
"\n",
"print(np.allclose(ans1, ans2, 1e-4, 1e-4))"
"print(jnp.allclose(ans1, ans2, 1e-4, 1e-4))"
],
"execution_count": 18,
"outputs": [
@ -1086,7 +1086,7 @@
"def hvp_revrev(f, primals, tangents):\n",
" x, = primals\n",
" v, = tangents\n",
" return grad(lambda x: np.vdot(grad(f)(x), v))(x)\n",
" return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)\n",
"\n",
"\n",
"print(\"Forward over reverse\")\n",
@ -1097,7 +1097,7 @@
"%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))\n",
"\n",
"print(\"Naive full Hessian materialization\")\n",
"%timeit -n10 -r3 np.tensordot(hessian(f)(X), V, 2)"
"%timeit -n10 -r3 jnp.tensordot(hessian(f)(X), V, 2)"
],
"execution_count": 20,
"outputs": [
@ -1158,7 +1158,7 @@
"# First, use a list comprehension to loop over rows in the matrix M.\n",
"def loop_mjp(f, x, M):\n",
" y, vjp_fun = vjp(f, x)\n",
" return np.vstack([vjp_fun(mi) for mi in M])\n",
" return jnp.vstack([vjp_fun(mi) for mi in M])\n",
"\n",
"# Now, use vmap to build a computation that does a single fast matrix-matrix\n",
"# multiply, rather than an outer loop over vector-matrix multiplies.\n",
@ -1179,7 +1179,7 @@
"vmap_vs = vmap_mjp(f, W, M=U)\n",
"%timeit -n10 -r3 vmap_mjp(f, W, M=U)\n",
"\n",
"assert np.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'"
"assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'"
],
"execution_count": 21,
"outputs": [
@ -1211,7 +1211,7 @@
"def loop_jmp(f, x, M):\n",
" # jvp immediately returns the primal and tangent values as a tuple,\n",
" # so we'll compute and select the tangents in a list comprehension\n",
" return np.vstack([jvp(f, (W,), (mi,))[1] for mi in M])\n",
" return jnp.vstack([jvp(f, (W,), (mi,))[1] for mi in M])\n",
"\n",
"def vmap_jmp(f, x, M):\n",
" _jvp = lambda s: jvp(f, (W,), (s,))[1]\n",
@ -1227,7 +1227,7 @@
"print('\\nVmapped Jacobian-Matrix product')\n",
"%timeit -n10 -r3 vmap_jmp(f, W, M=S)\n",
"\n",
"assert np.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'"
"assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'"
],
"execution_count": 22,
"outputs": [
@ -1281,11 +1281,11 @@
" # Use vmap to do a matrix-Jacobian product.\n",
" # Here, the matrix is the Euclidean basis, so we get all\n",
" # entries in the Jacobian at once. \n",
" J, = vmap(vjp_fun, in_axes=0)(np.eye(len(y)))\n",
" J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))\n",
" return J\n",
" return jacfun\n",
"\n",
"assert np.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'"
"assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'"
],
"execution_count": 0,
"outputs": []
@ -1303,11 +1303,11 @@
"def our_jacfwd(f):\n",
" def jacfun(x):\n",
" _jvp = lambda s: jvp(f, (x,), (s,))[1]\n",
" Jt =vmap(_jvp, in_axes=1)(np.eye(len(x)))\n",
" return np.transpose(Jt)\n",
" Jt =vmap(_jvp, in_axes=1)(jnp.eye(len(x)))\n",
" return jnp.transpose(Jt)\n",
" return jacfun\n",
"\n",
"assert np.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'"
"assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'"
],
"execution_count": 0,
"outputs": []
@ -1351,7 +1351,7 @@
" else:\n",
" raise ValueError\n",
" except ValueError:\n",
" return np.pi * x\n",
" return jnp.pi * x\n",
"\n",
"y, f_vjp = vjp(f, 4.)\n",
"print(jit(f_vjp)(1.))"
@ -1398,7 +1398,7 @@
},
"source": [
"def f(z):\n",
" x, y = np.real(z), np.imag(z)\n",
" x, y = jnp.real(z), jnp.imag(z)\n",
" return u(x, y) + v(x, y) * 1j\n",
"\n",
"def g(x, y):\n",
@ -1465,7 +1465,7 @@
" a, b, c, d = random.uniform(subkey, (4,))\n",
"\n",
" def fun(z):\n",
" x, y = np.real(z), np.imag(z)\n",
" x, y = jnp.real(z), jnp.imag(z)\n",
" return u(x, y) + v(x, y) * 1j\n",
"\n",
" def u(x, y):\n",
@ -1490,7 +1490,7 @@
" grad(u, 1)(x, y) * d +\n",
" grad(v, 0)(x, y) * c * 1j+\n",
" grad(v, 1)(x, y) * d * 1j)\n",
" print(np.allclose(ans, expected))"
" print(jnp.allclose(ans, expected))"
],
"execution_count": 0,
"outputs": []
@ -1567,7 +1567,7 @@
" a, b, c, d = random.uniform(subkey, (4,))\n",
"\n",
" def fun(z):\n",
" x, y = np.real(z), np.imag(z)\n",
" x, y = jnp.real(z), jnp.imag(z)\n",
" return u(x, y) + v(x, y) * 1j\n",
"\n",
" def u(x, y):\n",
@ -1584,7 +1584,7 @@
" # cotangent vector\n",
" key, subkey = random.split(key)\n",
" c, d = random.uniform(subkey, (2,))\n",
" z_bar = np.array(c + d * 1j) # for dtype control\n",
" z_bar = jnp.array(c + d * 1j) # for dtype control\n",
"\n",
" # check vjp\n",
" _, fun_vjp = vjp(fun, z)\n",
@ -1593,7 +1593,7 @@
" grad(v, 0)(x, y) * (-d) +\n",
" grad(u, 1)(x, y) * c * (-1j) +\n",
" grad(v, 1)(x, y) * (-d) * (-1j))\n",
" assert np.allclose(ans, expected, atol=1e-5, rtol=1e-5)"
" assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5)"
],
"execution_count": 0,
"outputs": []
@ -1638,7 +1638,7 @@
},
"source": [
"def f(z):\n",
" x, y = np.real(z), np.imag(z)\n",
" x, y = jnp.real(z), jnp.imag(z)\n",
" return x**2 + y**2\n",
"\n",
"z = 3. + 4j\n",
@ -1685,7 +1685,7 @@
},
"source": [
"def f(z):\n",
" return np.sin(z)\n",
" return jnp.sin(z)\n",
"\n",
"z = 3. + 4j\n",
"grad(f, holomorphic=True)(z)"
@ -1729,7 +1729,7 @@
},
"source": [
"def f(z):\n",
" return np.conjugate(z)\n",
" return jnp.conjugate(z)\n",
"\n",
"z = 3. + 4j\n",
"grad(f, holomorphic=True)(z) # f is not actually holomorphic!"
@ -1788,13 +1788,13 @@
}
},
"source": [
"A = np.array([[5., 2.+3j, 5j],\n",
"A = jnp.array([[5., 2.+3j, 5j],\n",
" [2.-3j, 7., 1.+7j],\n",
" [-5j, 1.-7j, 12.]])\n",
"\n",
"def f(X):\n",
" L = np.linalg.cholesky(X)\n",
" return np.sum((L - np.sin(L))**2)\n",
" L = jnp.linalg.cholesky(X)\n",
" return jnp.sum((L - jnp.sin(L))**2)\n",
"\n",
"grad(f, holomorphic=True)(A)"
],


@ -64,7 +64,7 @@
},
"source": [
"### import jax.numpy (almost-drop-in for numpy) and gradient operators.\n",
"import jax.numpy as np\n",
"import jax.numpy as jnp\n",
"from jax import grad"
],
"execution_count": 0,
@ -94,8 +94,8 @@
}
},
"source": [
"f = lambda x : np.exp(x)\n",
"g = lambda x : np.square(x)\n",
"f = lambda x : jnp.exp(x)\n",
"g = lambda x : jnp.square(x)\n",
"print(grad(f)(1.)) # = e^{1}\n",
"print(grad(grad(f))(1.))\n",
"print(grad(grad(grad(f)))(1.))\n",
@ -184,7 +184,7 @@
"def loss(params, inputs, targets):\n",
" # Computes average loss for the batch\n",
" predictions = net_apply(params, inputs)\n",
" return np.mean((targets - predictions)**2)"
" return jnp.mean((targets - predictions)**2)"
],
"execution_count": 0,
"outputs": []
@ -202,8 +202,8 @@
},
"source": [
"# batch the inference across K=100\n",
"xrange_inputs = np.linspace(-5,5,100).reshape((100, 1)) # (k, 1)\n",
"targets = np.sin(xrange_inputs)\n",
"xrange_inputs = jnp.linspace(-5,5,100).reshape((100, 1)) # (k, 1)\n",
"targets = jnp.sin(xrange_inputs)\n",
"predictions = vmap(partial(net_apply, net_params))(xrange_inputs)\n",
"losses = vmap(partial(loss, net_params))(xrange_inputs, targets) # per-input loss\n",
"plt.plot(xrange_inputs, predictions, label='prediction')\n",
@ -247,7 +247,7 @@
"colab": {}
},
"source": [
"import numpy as onp\n",
"import numpy as np\n",
"from jax.experimental import optimizers\n",
"from jax.tree_util import tree_multimap # Element-wise manipulation of collections of numpy arrays "
],
@ -292,7 +292,7 @@
},
"source": [
"# batch the inference across K=100\n",
"targets = np.sin(xrange_inputs)\n",
"targets = jnp.sin(xrange_inputs)\n",
"predictions = vmap(partial(net_apply, net_params))(xrange_inputs)\n",
"losses = vmap(partial(loss, net_params))(xrange_inputs, targets) # per-input loss\n",
"plt.plot(xrange_inputs, predictions, label='prediction')\n",
@ -360,7 +360,7 @@
"source": [
"# gradients of gradients test for MAML\n",
"# check numerics\n",
"g = lambda x, y : np.square(x) + y\n",
"g = lambda x, y : jnp.square(x) + y\n",
"x0 = 2.\n",
"y0 = 1.\n",
"print('grad(g)(x0) = {}'.format(grad(g)(x0, y0))) # 2x = 4\n",
@ -432,8 +432,8 @@
"source": [
"x1 = xrange_inputs\n",
"y1 = targets\n",
"x2 = np.array([0.])\n",
"y2 = np.array([0.])\n",
"x2 = jnp.array([0.])\n",
"y2 = jnp.array([0.])\n",
"maml_loss(net_params, x1, y1, x2, y2)"
],
"execution_count": 0,
@ -491,14 +491,14 @@
"# Adam optimization\n",
"for i in range(20000):\n",
" # define the task\n",
" A = onp.random.uniform(low=0.1, high=.5)\n",
" phase = onp.random.uniform(low=0., high=np.pi)\n",
" A = np.random.uniform(low=0.1, high=.5)\n",
" phase = np.random.uniform(low=0., high=jnp.pi)\n",
" # meta-training inner split (K examples)\n",
" x1 = onp.random.uniform(low=-5., high=5., size=(K,1))\n",
" y1 = A * onp.sin(x1 + phase)\n",
" x1 = np.random.uniform(low=-5., high=5., size=(K,1))\n",
" y1 = A * np.sin(x1 + phase)\n",
" # meta-training outer split (1 example). Like cross-validating with respect to one example.\n",
" x2 = onp.random.uniform(low=-5., high=5.)\n",
" y2 = A * onp.sin(x2 + phase)\n",
" x2 = np.random.uniform(low=-5., high=5.)\n",
" y2 = A * np.sin(x2 + phase)\n",
" opt_state, l = step(i, opt_state, x1, y1, x2, y2)\n",
" np_maml_loss.append(l)\n",
" if i % 1000 == 0:\n",
@ -548,13 +548,13 @@
},
"source": [
"# batch the inference across K=100\n",
"targets = np.sin(xrange_inputs)\n",
"targets = jnp.sin(xrange_inputs)\n",
"predictions = vmap(partial(net_apply, net_params))(xrange_inputs)\n",
"plt.plot(xrange_inputs, predictions, label='pre-update predictions')\n",
"plt.plot(xrange_inputs, targets, label='target')\n",
"\n",
"x1 = onp.random.uniform(low=-5., high=5., size=(K,1))\n",
"y1 = 1. * onp.sin(x1 + 0.)\n",
"x1 = np.random.uniform(low=-5., high=5., size=(K,1))\n",
"y1 = 1. * np.sin(x1 + 0.)\n",
"\n",
"for i in range(1,5):\n",
" net_params = inner_update(net_params, x1, y1)\n",
@ -619,16 +619,16 @@
" As = []\n",
" phases = []\n",
" for _ in range(outer_batch_size): \n",
" As.append(onp.random.uniform(low=0.1, high=.5))\n",
" phases.append(onp.random.uniform(low=0., high=np.pi))\n",
" As.append(np.random.uniform(low=0.1, high=.5))\n",
" phases.append(np.random.uniform(low=0., high=jnp.pi))\n",
" def get_batch():\n",
" xs, ys = [], []\n",
" for A, phase in zip(As, phases):\n",
" x = onp.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))\n",
" y = A * onp.sin(x + phase)\n",
" x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))\n",
" y = A * np.sin(x + phase)\n",
" xs.append(x)\n",
" ys.append(y)\n",
" return np.stack(xs), np.stack(ys)\n",
" return jnp.stack(xs), jnp.stack(ys)\n",
" x1, y1 = get_batch()\n",
" x2, y2 = get_batch()\n",
" return x1, y1, x2, y2"
@ -734,7 +734,7 @@
"# returns scalar for all tasks.\n",
"def batch_maml_loss(p, x1_b, y1_b, x2_b, y2_b):\n",
" task_losses = vmap(partial(maml_loss, p))(x1_b, y1_b, x2_b, y2_b)\n",
" return np.mean(task_losses)\n",
" return jnp.mean(task_losses)\n",
"\n",
"@jit\n",
"def step(i, opt_state, x1, y1, x2, y2):\n",
@ -796,13 +796,13 @@
},
"source": [
"# batch the inference across K=100\n",
"targets = np.sin(xrange_inputs)\n",
"targets = jnp.sin(xrange_inputs)\n",
"predictions = vmap(partial(net_apply, net_params))(xrange_inputs)\n",
"plt.plot(xrange_inputs, predictions, label='pre-update predictions')\n",
"plt.plot(xrange_inputs, targets, label='target')\n",
"\n",
"x1 = onp.random.uniform(low=-5., high=5., size=(10,1))\n",
"y1 = 1. * onp.sin(x1 + 0.)\n",
"x1 = np.random.uniform(low=-5., high=5., size=(10,1))\n",
"y1 = 1. * np.sin(x1 + 0.)\n",
"\n",
"for i in range(1,3):\n",
" net_params = inner_update(net_params, x1, y1)\n",
@ -851,8 +851,8 @@
},
"source": [
"# Comparison of maml_loss for task batch size = 1 vs. task batch size = 8\n",
"plt.plot(onp.convolve(np_maml_loss, [.05]*20), label='task_batch=1')\n",
"plt.plot(onp.convolve(np_batched_maml_loss, [.05]*20), label='task_batch=4')\n",
"plt.plot(np.convolve(np_maml_loss, [.05]*20), label='task_batch=1')\n",
"plt.plot(np.convolve(np_batched_maml_loss, [.05]*20), label='task_batch=4')\n",
"plt.ylim(0., 1e-1)\n",
"plt.legend()"
],


@ -60,7 +60,7 @@
},
"outputs": [],
"source": [
"import jax.numpy as np\n",
"import jax.numpy as jnp\n",
"from jax import grad, jit, vmap\n",
"from jax import random"
]
@ -132,17 +132,17 @@
"from jax.scipy.special import logsumexp\n",
"\n",
"def relu(x):\n",
" return np.maximum(0, x)\n",
" return jnp.maximum(0, x)\n",
"\n",
"def predict(params, image):\n",
" # per-example predictions\n",
" activations = image\n",
" for w, b in params[:-1]:\n",
" outputs = np.dot(w, activations) + b\n",
" outputs = jnp.dot(w, activations) + b\n",
" activations = relu(outputs)\n",
" \n",
" final_w, final_b = params[-1]\n",
" logits = np.dot(final_w, activations) + final_b\n",
" logits = jnp.dot(final_w, activations) + final_b\n",
" return logits - logsumexp(logits)"
]
},
@ -267,18 +267,18 @@
},
"outputs": [],
"source": [
"def one_hot(x, k, dtype=np.float32):\n",
"def one_hot(x, k, dtype=jnp.float32):\n",
" \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n",
" return np.array(x[:, None] == np.arange(k), dtype)\n",
" return jnp.array(x[:, None] == jnp.arange(k), dtype)\n",
" \n",
"def accuracy(params, images, targets):\n",
" target_class = np.argmax(targets, axis=1)\n",
" predicted_class = np.argmax(batched_predict(params, images), axis=1)\n",
" return np.mean(predicted_class == target_class)\n",
" target_class = jnp.argmax(targets, axis=1)\n",
" predicted_class = jnp.argmax(batched_predict(params, images), axis=1)\n",
" return jnp.mean(predicted_class == target_class)\n",
"\n",
"def loss(params, images, targets):\n",
" preds = batched_predict(params, images)\n",
" return -np.mean(preds * targets)\n",
" return -jnp.mean(preds * targets)\n",
"\n",
"@jit\n",
"def update(params, x, y):\n",
@ -325,12 +325,12 @@
"\n",
"# Full train set\n",
"train_images, train_labels = train_data['image'], train_data['label']\n",
"train_images = np.reshape(train_images, (len(train_images), num_pixels))\n",
"train_images = jnp.reshape(train_images, (len(train_images), num_pixels))\n",
"train_labels = one_hot(train_labels, num_labels)\n",
"\n",
"# Full test set\n",
"test_images, test_labels = test_data['image'], test_data['label']\n",
"test_images = np.reshape(test_images, (len(test_images), num_pixels))\n",
"test_images = jnp.reshape(test_images, (len(test_images), num_pixels))\n",
"test_labels = one_hot(test_labels, num_labels)"
]
},
@ -429,7 +429,7 @@
"for epoch in range(num_epochs):\n",
" start_time = time.time()\n",
" for x, y in get_train_batches():\n",
" x = np.reshape(x, (len(x), num_pixels))\n",
" x = jnp.reshape(x, (len(x), num_pixels))\n",
" y = one_hot(y, num_labels)\n",
" params = update(params, x, y)\n",
" epoch_time = time.time() - start_time\n",


@ -37,7 +37,7 @@
"colab": {}
},
"source": [
"import jax.numpy as np\n",
"import jax.numpy as jnp\n",
"from jax import grad, jit, vmap\n",
"from jax import random"
],
@ -98,8 +98,8 @@
},
"source": [
"size = 3000\n",
"x = random.normal(key, (size, size), dtype=np.float32)\n",
"%timeit np.dot(x, x.T).block_until_ready() # runs on the GPU"
"x = random.normal(key, (size, size), dtype=jnp.float32)\n",
"%timeit jnp.dot(x, x.T).block_until_ready() # runs on the GPU"
],
"execution_count": 0,
"outputs": []
@ -124,9 +124,9 @@
"colab": {}
},
"source": [
"import numpy as onp # original CPU-backed NumPy\n",
"x = onp.random.normal(size=(size, size)).astype(onp.float32)\n",
"%timeit np.dot(x, x.T).block_until_ready()"
"import numpy as np\n",
"x = np.random.normal(size=(size, size)).astype(np.float32)\n",
"%timeit jnp.dot(x, x.T).block_until_ready()"
],
"execution_count": 0,
"outputs": []
@ -151,9 +151,9 @@
"source": [
"from jax import device_put\n",
"\n",
"x = onp.random.normal(size=(size, size)).astype(onp.float32)\n",
"x = np.random.normal(size=(size, size)).astype(np.float32)\n",
"x = device_put(x)\n",
"%timeit np.dot(x, x.T).block_until_ready()"
"%timeit jnp.dot(x, x.T).block_until_ready()"
],
"execution_count": 0,
"outputs": []
@ -186,8 +186,8 @@
"colab": {}
},
"source": [
"x = onp.random.normal(size=(size, size)).astype(onp.float32)\n",
"%timeit onp.dot(x, x.T)"
"x = np.random.normal(size=(size, size)).astype(np.float32)\n",
"%timeit np.dot(x, x.T)"
],
"execution_count": 0,
"outputs": []
@ -237,7 +237,7 @@
},
"source": [
"def selu(x, alpha=1.67, lmbda=1.05):\n",
" return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)\n",
" return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)\n",
"\n",
"x = random.normal(key, (1000000,))\n",
"%timeit selu(x).block_until_ready()"
@ -290,9 +290,9 @@
},
"source": [
"def sum_logistic(x):\n",
" return np.sum(1.0 / (1.0 + np.exp(-x)))\n",
" return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))\n",
"\n",
"x_small = np.arange(3.)\n",
"x_small = jnp.arange(3.)\n",
"derivative_fn = grad(sum_logistic)\n",
"print(derivative_fn(x_small))"
],
@ -319,8 +319,8 @@
"source": [
"def first_finite_differences(f, x):\n",
" eps = 1e-3\n",
" return np.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)\n",
" for v in np.eye(len(x))])\n",
" return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)\n",
" for v in jnp.eye(len(x))])\n",
"\n",
"\n",
"print(first_finite_differences(sum_logistic, x_small))"
@ -418,7 +418,7 @@
"batched_x = random.normal(key, (10, 100))\n",
"\n",
"def apply_matrix(v):\n",
" return np.dot(mat, v)"
" return jnp.dot(mat, v)"
],
"execution_count": 0,
"outputs": []
@ -442,7 +442,7 @@
},
"source": [
"def naively_batched_apply_matrix(v_batched):\n",
" return np.stack([apply_matrix(v) for v in v_batched])\n",
" return jnp.stack([apply_matrix(v) for v in v_batched])\n",
"\n",
"print('Naively batched')\n",
"%timeit naively_batched_apply_matrix(batched_x).block_until_ready()"
@ -457,7 +457,7 @@
"colab_type": "text"
},
"source": [
"We know how to batch this operation manually. In this case, `np.dot` handles extra batch dimensions transparently."
"We know how to batch this operation manually. In this case, `jnp.dot` handles extra batch dimensions transparently."
]
},
{
@ -470,7 +470,7 @@
"source": [
"@jit\n",
"def batched_apply_matrix(v_batched):\n",
" return np.dot(v_batched, mat.T)\n",
" return jnp.dot(v_batched, mat.T)\n",
"\n",
"print('Manually batched')\n",
"%timeit batched_apply_matrix(batched_x).block_until_ready()"


@ -63,14 +63,14 @@
"source": [
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"import numpy as onp\n",
"import numpy as np\n",
"\n",
"from sklearn.datasets import make_swiss_roll\n",
"\n",
"def sample_batch(size, noise=1.0):\n",
" x, _= make_swiss_roll(size, noise=noise)\n",
" x = x[:, [0, 2]] / 10.0\n",
" return onp.array(x)\n",
" return np.array(x)\n",
"\n",
"plt.scatter(*sample_batch(10**4).T, alpha=0.1)"
],
@ -131,7 +131,7 @@
},
"source": [
"import jax\n",
"import jax.numpy as np\n",
"import jax.numpy as jnp\n",
"from jax.experimental import optimizers\n",
"from jax.experimental import stax\n",
"from functools import partial\n",
@ -167,10 +167,10 @@
" # we use jax.vmap to vectorize jacobian function along batch dimension\n",
" batch_jacobian = jax.vmap(partial(jacobian, net_params))(inputs) # [batch, dim, dim]\n",
" \n",
" trace_jacobian = np.trace(batch_jacobian, axis1=1, axis2=2)\n",
" output_norm_sq = np.square(net_apply(net_params, inputs)).sum(axis=1)\n",
" trace_jacobian = jnp.trace(batch_jacobian, axis1=1, axis2=2)\n",
" output_norm_sq = jnp.square(net_apply(net_params, inputs)).sum(axis=1)\n",
" \n",
" return np.mean(trace_jacobian + 1/2 * output_norm_sq)\n",
" return jnp.mean(trace_jacobian + 1/2 * output_norm_sq)\n",
"\n",
"\n",
"@jax.jit\n",
@ -243,16 +243,16 @@
" clear_output(True)\n",
" plt.figure(figsize=[16, 8])\n",
" plt.subplot(1, 2, 1)\n",
" plt.title(\"mean loss = %.3f\" % np.mean(np.array(loss_history[-32:])))\n",
" plt.scatter(np.arange(len(loss_history)), loss_history)\n",
" plt.title(\"mean loss = %.3f\" % jnp.mean(jnp.array(loss_history[-32:])))\n",
" plt.scatter(jnp.arange(len(loss_history)), loss_history)\n",
" plt.grid()\n",
" \n",
" plt.subplot(1, 2, 2)\n",
" net_params = get_params(opt_state)\n",
" xx = np.stack(np.meshgrid(np.linspace(-1.5, 2.0, 50), np.linspace(-1.5, 2.0, 50)), axis=-1).reshape(-1, 2)\n",
" xx = jnp.stack(jnp.meshgrid(jnp.linspace(-1.5, 2.0, 50), jnp.linspace(-1.5, 2.0, 50)), axis=-1).reshape(-1, 2)\n",
" scores = net_apply(net_params, xx)\n",
" scores_norm = np.linalg.norm(scores, axis=-1, ord=2, keepdims=True)\n",
" scores_log1p = scores / (scores_norm + 1e-9) * np.log1p(scores_norm)\n",
" scores_norm = jnp.linalg.norm(scores, axis=-1, ord=2, keepdims=True)\n",
" scores_log1p = scores / (scores_norm + 1e-9) * jnp.log1p(scores_norm)\n",
"\n",
" plt.quiver(*xx.T, *scores_log1p.T, width=0.002, color='green')\n",
" plt.xlim(-1.5, 2.0)\n",
@ -301,10 +301,10 @@
"plt.figure(figsize=[16, 16])\n",
"\n",
"net_params = get_params(opt_state)\n",
"xx = np.stack(np.meshgrid(np.linspace(-1.5, 1.5, 50), np.linspace(-1.5, 1.5, 50)), axis=-1).reshape(-1, 2)\n",
"xx = jnp.stack(jnp.meshgrid(jnp.linspace(-1.5, 1.5, 50), jnp.linspace(-1.5, 1.5, 50)), axis=-1).reshape(-1, 2)\n",
"scores = net_apply(net_params, xx)\n",
"scores_norm = np.linalg.norm(scores, axis=-1, ord=2, keepdims=True)\n",
"scores_log1p = scores / (scores_norm + 1e-9) * np.log1p(scores_norm)\n",
"scores_norm = jnp.linalg.norm(scores, axis=-1, ord=2, keepdims=True)\n",
"scores_log1p = scores / (scores_norm + 1e-9) * jnp.log1p(scores_norm)\n",
"\n",
"plt.quiver(*xx.T, *scores_log1p.T, width=0.002, color='green')\n",
"plt.scatter(*sample_batch(10_000).T, alpha=0.25)"
@ -412,11 +412,11 @@
" for t in range(num_steps):\n",
" key, subkey = jax.random.split(key)\n",
" z_t = jax.random.normal(subkey, shape=x_t.shape)\n",
" x_t = x_t + eps / 2 * net_apply(net_params, x_t) + np.sqrt(eps) * temperature * z_t\n",
" x_t = x_t + eps / 2 * net_apply(net_params, x_t) + jnp.sqrt(eps) * temperature * z_t\n",
" x_sequence.append(x_t)\n",
" eps *= eps_decay\n",
" \n",
" return np.stack(x_sequence)"
" return jnp.stack(x_sequence)"
],
"execution_count": 0,
"outputs": []
@ -438,7 +438,7 @@
"key = jax.random.PRNGKey(42)\n",
"net_params = get_params(opt_state)\n",
"\n",
"for x_initial in np.array([[-1.5, -1.5], [0, 0], [1.5, 0]]):\n",
"for x_initial in jnp.array([[-1.5, -1.5], [0, 0], [1.5, 0]]):\n",
" key, subkey = jax.random.split(key)\n",
" # sample x sequence\n",
" xx = sample_langevin(x_initial, key=subkey, net_params=net_params)\n",
@ -446,16 +446,16 @@
"\n",
" # draw arrows for each mcmc step\n",
" deltas = (xx[1:] - xx[:-1])\n",
" deltas = deltas - deltas / np.linalg.norm(deltas, keepdims=True, axis=-1) * 0.04\n",
" deltas = deltas - deltas / jnp.linalg.norm(deltas, keepdims=True, axis=-1) * 0.04\n",
" for i, arrow in enumerate(deltas):\n",
" plt.arrow(xx[i][0], xx[i][1], arrow[0], arrow[1], width=1e-4, head_width=2e-2, color=\"orange\")\n",
" \n",
"# plot data points and gradients\n",
"plt.plot()\n",
"xx = np.stack(np.meshgrid(np.linspace(-1.5, 1.5, 50), np.linspace(-1.5, 1.5, 50)), axis=-1).reshape(-1, 2)\n",
"xx = jnp.stack(jnp.meshgrid(jnp.linspace(-1.5, 1.5, 50), jnp.linspace(-1.5, 1.5, 50)), axis=-1).reshape(-1, 2)\n",
"scores = net_apply(net_params, xx)\n",
"scores_norm = np.linalg.norm(scores, axis=-1, ord=2, keepdims=True)\n",
"scores_log1p = scores / (scores_norm + 1e-9) * np.log1p(scores_norm)\n",
"scores_norm = jnp.linalg.norm(scores, axis=-1, ord=2, keepdims=True)\n",
"scores_log1p = scores / (scores_norm + 1e-9) * jnp.log1p(scores_norm)\n",
"plt.quiver(*xx.T, *scores_log1p.T, width=0.002, color='green')\n",
"plt.scatter(*sample_batch(10_000).T, alpha=0.025)"
],
@ -516,7 +516,7 @@
"@jax.jit\n",
"def compute_ssm_loss(net_params, inputs, key):\n",
" apply = jax.jit(partial(net_apply, net_params))\n",
" batch_dot = partial(np.einsum, 'bu,bu->b')\n",
" batch_dot = partial(jnp.einsum, 'bu,bu->b')\n",
" \n",
" # generate random vectors from N(0, I)\n",
" v = jax.random.normal(key, shape=inputs.shape)\n",
@ -524,7 +524,7 @@
" # predict score and comput jacobian of score times v\n",
" score, jac_v = jax.jvp(apply, [inputs], [v])\n",
" \n",
" return np.mean(batch_dot(v, jac_v) + 1/2 * batch_dot(v, score) ** 2)\n",
" return jnp.mean(batch_dot(v, jac_v) + 1/2 * batch_dot(v, score) ** 2)\n",
"\n",
"@jax.jit\n",
"def train_step(step_i, opt_state, batch, key):\n",
@ -588,16 +588,16 @@
" clear_output(True)\n",
" plt.figure(figsize=[16, 8])\n",
" plt.subplot(1, 2, 1)\n",
" plt.title(\"mean loss = %.3f\" % np.mean(np.array(loss_history[-32:])))\n",
" plt.scatter(np.arange(len(loss_history)), loss_history)\n",
" plt.title(\"mean loss = %.3f\" % jnp.mean(jnp.array(loss_history[-32:])))\n",
" plt.scatter(jnp.arange(len(loss_history)), loss_history)\n",
" plt.grid()\n",
" \n",
" plt.subplot(1, 2, 2)\n",
" net_params = get_params(opt_state)\n",
" xx = np.stack(np.meshgrid(np.linspace(-1.5, 2.0, 50), np.linspace(-1.5, 2.0, 50)), axis=-1).reshape(-1, 2)\n",
" xx = jnp.stack(jnp.meshgrid(jnp.linspace(-1.5, 2.0, 50), jnp.linspace(-1.5, 2.0, 50)), axis=-1).reshape(-1, 2)\n",
" scores = net_apply(net_params, xx)\n",
" scores_norm = np.linalg.norm(scores, axis=-1, ord=2, keepdims=True)\n",
" scores_log1p = scores / (scores_norm + 1e-9) * np.log1p(scores_norm)\n",
" scores_norm = jnp.linalg.norm(scores, axis=-1, ord=2, keepdims=True)\n",
" scores_log1p = scores / (scores_norm + 1e-9) * jnp.log1p(scores_norm)\n",
"\n",
" plt.quiver(*xx.T, *scores_log1p.T, width=0.002, color='green')\n",
" plt.xlim(-1.5, 2.0)\n",
@ -644,7 +644,7 @@
},
"source": [
"from sklearn.datasets import load_digits\n",
"import numpy as old_np\n",
"import numpy as np\n",
"\n",
"X, _ = load_digits(return_X_y=True)\n",
"\n",
@ -654,8 +654,8 @@
" \n",
"\n",
"def sample_batch(size, noise=0.1):\n",
" ix = old_np.random.randint(0, len(X), size=size)\n",
" return np.array(X[ix] / 16 + noise * old_np.random.randn(size, 64))"
" ix = np.random.randint(0, len(X), size=size)\n",
" return jnp.array(X[ix] / 16 + noise * np.random.randn(size, 64))"
],
"execution_count": 32,
"outputs": [
@ -721,8 +721,8 @@
" \n",
" if i % 500 == 0:\n",
" clear_output(True)\n",
" plt.title(\"mean loss = %.3f\" % np.mean(np.array(loss_history[-32:])))\n",
" plt.scatter(np.arange(len(loss_history)), loss_history)\n",
" plt.title(\"mean loss = %.3f\" % jnp.mean(jnp.array(loss_history[-32:])))\n",
" plt.scatter(jnp.arange(len(loss_history)), loss_history)\n",
" plt.show()"
],
"execution_count": 35,


@ -59,12 +59,12 @@
"import jax\n",
"\n",
"from jax import lax\n",
"from jax import numpy as np\n",
"from jax import scipy\n",
"import jax.numpy as jnp\n",
"import jax.scipy as jsp\n",
"from jax import random\n",
"\n",
"import numpy as onp\n",
"import scipy as oscipy"
"import numpy as np\n",
"import scipy as sp"
],
"execution_count": 0,
"outputs": []
@ -87,14 +87,14 @@
"colab": {}
},
"source": [
"onp.random.seed(10009)\n",
"np.random.seed(10009)\n",
"\n",
"num_features = 10\n",
"num_points = 100\n",
"\n",
"true_beta = onp.random.randn(num_features).astype(np.float32)\n",
"all_x = onp.random.randn(num_points, num_features).astype(np.float32)\n",
"y = (onp.random.rand(num_points) < oscipy.special.expit(all_x.dot(true_beta))).astype(np.int32)"
"true_beta = np.random.randn(num_features).astype(jnp.float32)\n",
"all_x = np.random.randn(num_points, num_features).astype(jnp.float32)\n",
"y = (np.random.rand(num_points) < sp.special.expit(all_x.dot(true_beta))).astype(jnp.int32)"
],
"execution_count": 0,
"outputs": []
@ -165,9 +165,9 @@
"source": [
"def log_joint(beta):\n",
" result = 0.\n",
" # Note that no `axis` parameter is provided to `np.sum`.\n",
" result = result + np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=1.))\n",
" result = result + np.sum(-np.log(1 + np.exp(-(2*y-1) * np.dot(all_x, beta))))\n",
" # Note that no `axis` parameter is provided to `jnp.sum`.\n",
" result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.))\n",
" result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta))))\n",
" return result"
],
"execution_count": 0,
@ -185,7 +185,7 @@
}
},
"source": [
"log_joint(onp.random.randn(num_features))"
"log_joint(np.random.randn(num_features))"
],
"execution_count": 13,
"outputs": [
@ -218,9 +218,9 @@
"# This doesn't work, because we didn't write `log_prob()` to handle batching.\n",
"try:\n",
" batch_size = 10\n",
" batched_test_beta = onp.random.randn(batch_size, num_features)\n",
" batched_test_beta = np.random.randn(batch_size, num_features)\n",
"\n",
" log_joint(onp.random.randn(batch_size, num_features))\n",
" log_joint(np.random.randn(batch_size, num_features))\n",
"except ValueError as e:\n",
" print(\"Caught expected exception \" + str(e))"
],
@ -258,12 +258,12 @@
" # Here (and below) `sum` needs an `axis` parameter. At best, forgetting to set axis\n",
" # or setting it incorrectly yields an error; at worst, it silently changes the\n",
" # semantics of the model.\n",
" result = result + np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=1.),\n",
" result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.),\n",
" axis=-1)\n",
" # Note the multiple transposes. Getting this right is not rocket science,\n",
" # but it's also not totally mindless. (I didn't get it right on the first\n",
" # try.)\n",
" result = result + np.sum(-np.log(1 + np.exp(-(2*y-1) * np.dot(all_x, beta.T).T)),\n",
" result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta.T).T)),\n",
" axis=-1)\n",
" return result"
],
@ -283,7 +283,7 @@
},
"source": [
"batch_size = 10\n",
"batched_test_beta = onp.random.randn(batch_size, num_features)\n",
"batched_test_beta = np.random.randn(batch_size, num_features)\n",
"\n",
"batched_log_joint(batched_test_beta)"
],
@ -383,9 +383,9 @@
"@jax.jit\n",
"def log_joint(beta):\n",
" result = 0.\n",
" # Note that no `axis` parameter is provided to `np.sum`.\n",
" result = result + np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=10.))\n",
" result = result + np.sum(-np.log(1 + np.exp(-(2*y-1) * np.dot(all_x, beta))))\n",
" # Note that no `axis` parameter is provided to `jnp.sum`.\n",
" result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=10.))\n",
" result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta))))\n",
" return result\n",
"\n",
"batched_log_joint = jax.jit(jax.vmap(log_joint))"
@ -412,8 +412,8 @@
},
"source": [
"def elbo(beta_loc, beta_log_scale, epsilon):\n",
" beta_sample = beta_loc + np.exp(beta_log_scale) * epsilon\n",
" return np.mean(batched_log_joint(beta_sample), 0) + np.sum(beta_log_scale - 0.5 * onp.log(2*onp.pi))\n",
" beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon\n",
" return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi))\n",
" \n",
"elbo = jax.jit(elbo, static_argnums=(1, 2))\n",
"elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))"
@ -452,8 +452,8 @@
"\n",
"key = random.PRNGKey(10003)\n",
"\n",
"beta_loc = np.zeros(num_features, np.float32)\n",
"beta_log_scale = np.zeros(num_features, np.float32)\n",
"beta_loc = jnp.zeros(num_features, jnp.float32)\n",
"beta_log_scale = jnp.zeros(num_features, jnp.float32)\n",
"\n",
"step_size = 0.01\n",
"batch_size = 128\n",
@ -603,8 +603,8 @@
"source": [
"figure(figsize=(7, 7))\n",
"plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')\n",
"plot(true_beta, beta_loc + 2*np.exp(beta_log_scale), 'r.', label='Approximated Posterior $2\\sigma$ Error Bars')\n",
"plot(true_beta, beta_loc - 2*np.exp(beta_log_scale), 'r.')\n",
"plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label='Approximated Posterior $2\\sigma$ Error Bars')\n",
"plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')\n",
"plot_scale = 3\n",
"plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')\n",
"xlabel('True beta')\n",


@ -94,7 +94,7 @@ reference cycles.
Here is a simple example::
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node
from jax import numpy as np
import jax.numpy as jnp
# The structured value to be transformed
value_structured = [1., (2., 3.)]
@ -129,7 +129,7 @@ treated as leaves::
(1., {'b': 2., 'a': 3.}),
1.,
None,
np.zeros(2),
jnp.zeros(2),
Point(1., 2.)
]
def show_example(structured):


@ -9,9 +9,9 @@ surprising bugs where a silent rank promotion masks an underlying shape error.
Here's an example of rank promotion:
>>> import numpy as onp
>>> x = onp.arange(12).reshape(4, 3)
>>> y = onp.array([0, 1, 0])
>>> import numpy as np
>>> x = np.arange(12).reshape(4, 3)
>>> y = np.array([0, 1, 0])
>>> x + y
array([[ 0, 2, 2],
[ 3, 5, 5],


@ -60,21 +60,21 @@ following table, where, for example
</table><p>
.. The table above was generated by the following Python code.
import numpy as onp
from jax import numpy as np
import numpy as np
import jax.numpy as jnp
types = [onp.bool_, onp.uint8, onp.uint16, onp.uint32, onp.uint64,
onp.int8, onp.int16, onp.int32, onp.int64,
np.bfloat16, onp.float16, onp.float32, onp.float64,
onp.complex64, onp.complex128]
types = [np.bool_, np.uint8, np.uint16, np.uint32, np.uint64,
np.int8, np.int16, np.int32, np.int64,
jnp.bfloat16, np.float16, np.float32, np.float64,
np.complex64, np.complex128]
def name(d):
d = onp.dtype(d)
if d == onp.dtype(np.bfloat16):
d = np.dtype(d)
if d == np.dtype(jnp.bfloat16):
return "bf"
return "{}{}".format(
d.kind,
d.itemsize // 2 if onp.issubdtype(d, onp.complexfloating) else d.itemsize)
d.itemsize // 2 if np.issubdtype(d, np.complexfloating) else d.itemsize)
out = "<tr><th></th>"
for t in types:
@ -84,8 +84,8 @@ following table, where, for example
for t1 in types:
out += "<tr><td>{}</td>".format(name(t1))
for t2 in types:
t = np.promote_types(t1, t2)
different = np.bfloat16 in (t1, t2) or t != onp.promote_types(t1, t2)
t = jnp.promote_types(t1, t2)
different = jnp.bfloat16 in (t1, t2) or t != np.promote_types(t1, t2)
out += "<td{}>{}</td>".format(" class=\"d\"" if different else "", name(t))
out += "</tr>\n"


@ -28,7 +28,7 @@ pairs can be composed in series using `stax.serial` or in parallel using
Here's an example:
```python
import jax.numpy as np
import jax.numpy as jnp
from jax import random
from jax.experimental import stax
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax
@ -48,7 +48,7 @@ in_shape = (-1, 28, 28, 1)
out_shape, net_params = net_init(rng, in_shape)
# Apply network to dummy inputs
inputs = np.zeros((128, 28, 28, 1))
inputs = jnp.zeros((128, 28, 28, 1))
predictions = net_apply(net_params, inputs)
```
@ -74,7 +74,7 @@ from jax import jit, grad
def loss(params, batch):
inputs, targets = batch
predictions = net_apply(params, inputs)
return np.sum((predictions - targets)**2)
return jnp.sum((predictions - targets)**2)
# Use optimizers to set optimizer initialization and update functions
opt_init, opt_update, get_params = optimizers.momentum(step_size=1e-3, mass=0.9)
@ -87,7 +87,7 @@ def step(i, opt_state, batch):
return opt_update(i, g, opt_state)
# Dummy input data stream
data_generator = ((np.zeros((128, 28, 28, 1)), np.zeros((128, 10)))
data_generator = ((jnp.zeros((128, 28, 28, 1)), jnp.zeros((128, 10)))
for _ in range(10))
# Optimize parameters in a loop