Replace np -> jnp, onp -> np in examples/ (#2971)

For context, see #2370
Author: Peter Hawkins, 2020-05-05 15:45:07 -04:00 (committed by GitHub)
Parent: b543652332
Commit: d59ecddfe8
12 changed files with 173 additions and 173 deletions
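In short, the examples now import plain NumPy as np and jax.numpy as jnp, replacing the older onp/np aliases. A minimal before/after sketch of the convention (illustrative only, not taken from any single file in this diff):

Before:

    import numpy as onp
    import jax.numpy as np
    x = np.dot(np.ones(3), onp.arange(3.0))

After:

    import numpy as np
    import jax.numpy as jnp
    x = jnp.dot(jnp.ones(3), np.arange(3.0))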

Changed file 1 of 12

@@ -25,7 +25,7 @@ import matplotlib.pyplot as plt
from jax.api import jit, grad, vmap
from jax import random
from jax.experimental import optimizers
import jax.numpy as np
import jax.numpy as jnp
import jax.scipy.stats.norm as norm
@@ -33,11 +33,11 @@ import jax.scipy.stats.norm as norm
def diag_gaussian_sample(rng, mean, log_std):
# Take a single sample from a diagonal multivariate Gaussian.
return mean + np.exp(log_std) * random.normal(rng, mean.shape)
return mean + jnp.exp(log_std) * random.normal(rng, mean.shape)
def diag_gaussian_logpdf(x, mean, log_std):
# Evaluate a single point on a diagonal multivariate Gaussian.
return np.sum(vmap(norm.logpdf)(x, mean, np.exp(log_std)))
return jnp.sum(vmap(norm.logpdf)(x, mean, jnp.exp(log_std)))
def elbo(logprob, rng, mean, log_std):
# Single-sample Monte Carlo estimate of the variational lower bound.
@@ -48,7 +48,7 @@ def batch_elbo(logprob, rng, params, num_samples):
# Average over a batch of random samples.
rngs = random.split(rng, num_samples)
vectorized_elbo = vmap(partial(elbo, logprob), in_axes=(0, None, None))
return np.mean(vectorized_elbo(rngs, *params))
return jnp.mean(vectorized_elbo(rngs, *params))
# ========= Helper function for plotting. =========
@@ -56,10 +56,10 @@ def batch_elbo(logprob, rng, params, num_samples):
@partial(jit, static_argnums=(0, 1, 2, 4))
def _mesh_eval(func, x_limits, y_limits, params, num_ticks):
# Evaluate func on a 2D grid defined by x_limits and y_limits.
x = np.linspace(*x_limits, num=num_ticks)
y = np.linspace(*y_limits, num=num_ticks)
X, Y = np.meshgrid(x, y)
xy_vec = np.stack([X.ravel(), Y.ravel()]).T
x = jnp.linspace(*x_limits, num=num_ticks)
y = jnp.linspace(*y_limits, num=num_ticks)
X, Y = jnp.meshgrid(x, y)
xy_vec = jnp.stack([X.ravel(), Y.ravel()]).T
zs = vmap(func, in_axes=(0, None))(xy_vec, params)
return X, Y, zs.reshape(X.shape)
@@ -69,7 +69,7 @@ def mesh_eval(func, x_limits, y_limits, params, num_ticks=101):
# ========= Define an intractable unnormalized density =========
def funnel_log_density(params):
return norm.logpdf(params[0], 0, np.exp(params[1])) + \
return norm.logpdf(params[0], 0, jnp.exp(params[1])) + \
norm.logpdf(params[1], 0, 1.35)
@@ -88,8 +88,8 @@ if __name__ == "__main__":
plt.show(block=False)
x_limits = [-2, 2]
y_limits = [-4, 2]
target_dist = lambda x, _: np.exp(funnel_log_density(x))
approx_dist = lambda x, params: np.exp(diag_gaussian_logpdf(x, *params))
target_dist = lambda x, _: jnp.exp(funnel_log_density(x))
approx_dist = lambda x, params: jnp.exp(diag_gaussian_logpdf(x, *params))
def callback(params, t):
print("Iteration {} lower bound {}".format(t, objective(params, t)))
@@ -117,8 +117,8 @@ if __name__ == "__main__":
# Set up optimizer.
D = 2
init_mean = np.zeros(D)
init_std = np.zeros(D)
init_mean = jnp.zeros(D)
init_std = jnp.zeros(D)
init_params = (init_mean, init_std)
opt_init, opt_update, get_params = optimizers.momentum(step_size=0.1, mass=0.9)
opt_state = opt_init(init_params)

Changed file 2 of 12

@@ -18,7 +18,7 @@ Model-predictive non-linear control example.
import collections
from jax import lax, grad, jacfwd, jacobian, vmap
import jax.numpy as np
import jax.numpy as jnp
import jax.ops as jo
@@ -52,8 +52,8 @@ ControlSpec = collections.namedtuple(
LqrSpec = collections.namedtuple('LqrSpec', 'Q q R r M A B')
dot = np.dot
mm = np.matmul
dot = jnp.dot
mm = jnp.matmul
def mv(mat, vec):
@@ -75,7 +75,7 @@ def fori_loop(lo, hi, loop, init):
def scan_fori_loop(lo, hi, loop, init):
def scan_f(x, t):
return loop(t, x), ()
x, _ = lax.scan(scan_f, init, np.arange(lo, hi))
x, _ = lax.scan(scan_f, init, jnp.arange(lo, hi))
return x
@@ -84,7 +84,7 @@ def trajectory(dynamics, U, x0):
T, _ = U.shape
d, = x0.shape
X = np.zeros((T + 1, d))
X = jnp.zeros((T + 1, d))
X = jo.index_update(X, jo.index[0], x0)
def loop(t, X):
@@ -110,8 +110,8 @@ def make_lqr_approx(p):
def approx(X, U):
assert X.shape[0] == T + 1 and U.shape[0] == T
U_pad = np.vstack((U, np.zeros((1,) + U.shape[1:])))
Q, q, R, r, M, A, B = _approx(np.arange(T + 1), X, U_pad)
U_pad = jnp.vstack((U, jnp.zeros((1,) + U.shape[1:])))
Q, q, R, r, M, A, B = _approx(jnp.arange(T + 1), X, U_pad)
return LqrSpec(Q, q, R[:T], r[:T], M[:T], A[:T], B[:T])
return approx
@@ -122,8 +122,8 @@ def lqr_solve(spec):
T, control_dim, _ = spec.R.shape
_, state_dim, _ = spec.Q.shape
K = np.zeros((T, control_dim, state_dim))
k = np.zeros((T, control_dim))
K = jnp.zeros((T, control_dim, state_dim))
k = jnp.zeros((T, control_dim))
def rev_loop(t_, state):
t = T - t_ - 1
@@ -139,8 +139,8 @@ def lqr_solve(spec):
G = R + mm(BtP, B)
H = mm(BtP, A) + M.T
h = r + mv(B.T, p)
K_ = -np.linalg.solve(G + EPS * np.eye(G.shape[0]), H)
k_ = -np.linalg.solve(G + EPS * np.eye(G.shape[0]), h)
K_ = -jnp.linalg.solve(G + EPS * jnp.eye(G.shape[0]), H)
k_ = -jnp.linalg.solve(G + EPS * jnp.eye(G.shape[0]), h)
P_ = Q + mm(AtP, A) + mm(K_.T, H)
p_ = q + mv(A.T, p) + mv(K_.T, h)
@@ -170,8 +170,8 @@ def lqr_predict(spec, x0):
U = jo.index_update(U, jo.index[t], u)
return spec, X, U
U = np.zeros((T, control_dim))
X = np.zeros((T + 1, state_dim))
U = jnp.zeros((T, control_dim))
X = jnp.zeros((T + 1, state_dim))
X = jo.index_update(X, jo.index[0], x0)
_, X, U = fori_loop(0, T, fwd_loop, (spec, X, U))
return X, U
@@ -186,7 +186,7 @@ def ilqr(iterations, p, x0, U):
def loop(_, state):
X, U = state
p_lqr = lqr_approx(X, U)
dX, dU = lqr_predict(p_lqr, np.zeros_like(x0))
dX, dU = lqr_predict(p_lqr, jnp.zeros_like(x0))
U = U + dU
X = trajectory(p.dynamics, U, X[0] + dX[0])
return X, U
@@ -200,7 +200,7 @@ def mpc_predict(solver, p, x0, U):
T = p.horizon
def zero_padded_controls_window(U, t):
U_pad = np.vstack((U, np.zeros(U.shape)))
U_pad = jnp.vstack((U, jnp.zeros(U.shape)))
return lax.dynamic_slice_in_dim(U_pad, t, T, axis=0)
def loop(t, state):
@@ -218,6 +218,6 @@ def mpc_predict(solver, p, x0, U):
U = jo.index_update(U, jo.index[t], ut)
return X, U
X = np.zeros((T + 1, p.state_dim))
X = jnp.zeros((T + 1, p.state_dim))
X = jo.index_update(X, jo.index[0], x0)
return fori_loop(0, T, loop, (X, U))

Changed file 3 of 12

@@ -16,11 +16,11 @@ from functools import partial
from unittest import SkipTest
from absl.testing import absltest
import numpy as onp
import numpy as np
from jax import lax
from jax import test_util as jtu
import jax.numpy as np
import jax.numpy as jnp
from examples import control
@@ -30,19 +30,19 @@ FLAGS = config.FLAGS
def one_step_lqr(dim, T):
Q = np.stack(T * (np.eye(dim),))
q = np.zeros((T, dim))
R = np.zeros((T, dim, dim))
r = np.zeros((T, dim))
M = np.zeros((T, dim, dim))
A = np.stack(T * (np.eye(dim),))
B = np.stack(T * (np.eye(dim),))
Q = jnp.stack(T * (jnp.eye(dim),))
q = jnp.zeros((T, dim))
R = jnp.zeros((T, dim, dim))
r = jnp.zeros((T, dim))
M = jnp.zeros((T, dim, dim))
A = jnp.stack(T * (jnp.eye(dim),))
B = jnp.stack(T * (jnp.eye(dim),))
return control.LqrSpec(Q, q, R, r, M, A, B)
def control_from_lqr(lqr):
T, dim, _ = lqr.Q.shape
dot = np.dot
dot = jnp.dot
def cost(t, x, u):
return (
@@ -59,7 +59,7 @@ def control_from_lqr(lqr):
def one_step_control(dim, T):
def cost(t, x, u):
return np.dot(x, x)
return jnp.dot(x, x)
def dynamics(t, x, u):
return x + u
@@ -77,33 +77,33 @@ class ControlExampleTest(jtu.JaxTestCase):
T = 10
U = np.ones((T, 1))
X = control.trajectory(dynamics, U, np.zeros(1))
expected = np.arange(T + 1) % num_states
expected = np.reshape(expected, (T + 1, 1))
U = jnp.ones((T, 1))
X = control.trajectory(dynamics, U, jnp.zeros(1))
expected = jnp.arange(T + 1) % num_states
expected = jnp.reshape(expected, (T + 1, 1))
self.assertAllClose(X, expected, check_dtypes=False)
U = 2 * np.ones((T, 1))
X = control.trajectory(dynamics, U, np.zeros(1))
expected = np.cumsum(2 * np.ones(T)) % num_states
expected = np.concatenate((np.zeros(1), expected))
expected = np.reshape(expected, (T + 1, 1))
U = 2 * jnp.ones((T, 1))
X = control.trajectory(dynamics, U, jnp.zeros(1))
expected = jnp.cumsum(2 * jnp.ones(T)) % num_states
expected = jnp.concatenate((jnp.zeros(1), expected))
expected = jnp.reshape(expected, (T + 1, 1))
self.assertAllClose(X, expected, check_dtypes=False)
def testTrajectoryTimeVarying(self):
T = 6
def clip(x, lo, hi):
return np.minimum(hi, np.maximum(lo, x))
return jnp.minimum(hi, jnp.maximum(lo, x))
def dynamics(t, x, u):
# computes `(x + u) if t > T else 0`
return (x + u) * clip(t - T, 0, 1)
U = np.ones((2 * T, 1))
X = control.trajectory(dynamics, U, np.zeros(1))
expected = np.concatenate((np.zeros(T + 1), np.arange(T)))
expected = np.reshape(expected, (2 * T + 1, 1))
U = jnp.ones((2 * T, 1))
X = control.trajectory(dynamics, U, jnp.zeros(1))
expected = jnp.concatenate((jnp.zeros(T + 1), jnp.arange(T)))
expected = jnp.reshape(expected, (2 * T + 1, 1))
self.assertAllClose(X, expected, check_dtypes=True)
@@ -112,20 +112,20 @@ class ControlExampleTest(jtu.JaxTestCase):
def position(x):
'''finds the index of a standard basis vector, e.g. [0, 1, 0] -> 1'''
x = np.cumsum(x)
x = jnp.cumsum(x)
x = 1 - x
return np.sum(x, dtype=np.int32)
return jnp.sum(x, dtype=jnp.int32)
def dynamics(t, x, u):
'''moves the next standard basis vector'''
idx = (position(x) + u[0]) % num_states
return lax.dynamic_slice_in_dim(np.eye(num_states), idx, 1)[0]
return lax.dynamic_slice_in_dim(jnp.eye(num_states), idx, 1)[0]
T = 8
U = np.ones((T, 1), dtype=np.int32)
X = control.trajectory(dynamics, U, np.eye(num_states, dtype=np.int32)[0])
expected = np.vstack((np.eye(num_states),) * 3)
U = jnp.ones((T, 1), dtype=jnp.int32)
X = control.trajectory(dynamics, U, jnp.eye(num_states, dtype=jnp.int32)[0])
expected = jnp.vstack((jnp.eye(num_states),) * 3)
self.assertAllClose(X, expected, check_dtypes=True)
@@ -133,13 +133,13 @@ class ControlExampleTest(jtu.JaxTestCase):
dim, T = 2, 10
p = one_step_lqr(dim, T)
K, k = control.lqr_solve(p)
K_ = -np.stack(T * (np.eye(dim),))
K_ = -jnp.stack(T * (jnp.eye(dim),))
self.assertAllClose(K, K_, check_dtypes=True, atol=1e-6, rtol=1e-6)
self.assertAllClose(k, np.zeros((T, dim)), check_dtypes=True)
self.assertAllClose(k, jnp.zeros((T, dim)), check_dtypes=True)
def testLqrPredict(self):
randn = onp.random.RandomState(0).randn
randn = np.random.RandomState(0).randn
dim, T = 2, 10
p = one_step_lqr(dim, T)
x0 = randn(dim)
@@ -147,35 +147,35 @@ class ControlExampleTest(jtu.JaxTestCase):
self.assertAllClose(X[0], x0, check_dtypes=True)
self.assertAllClose(U[0], -x0, check_dtypes=True,
atol=1e-6, rtol=1e-6)
self.assertAllClose(X[1:], np.zeros((T, 2)), check_dtypes=True,
self.assertAllClose(X[1:], jnp.zeros((T, 2)), check_dtypes=True,
atol=1e-6, rtol=1e-6)
self.assertAllClose(U[1:], np.zeros((T - 1, 2)), check_dtypes=True,
self.assertAllClose(U[1:], jnp.zeros((T - 1, 2)), check_dtypes=True,
atol=1e-6, rtol=1e-6)
def testIlqrWithLqrProblem(self):
randn = onp.random.RandomState(0).randn
randn = np.random.RandomState(0).randn
dim, T, num_iters = 2, 10, 3
lqr = one_step_lqr(dim, T)
p = control_from_lqr(lqr)
x0 = randn(dim)
X, U = control.ilqr(num_iters, p, x0, np.zeros((T, dim)))
X, U = control.ilqr(num_iters, p, x0, jnp.zeros((T, dim)))
self.assertAllClose(X[0], x0, check_dtypes=True)
self.assertAllClose(U[0], -x0, check_dtypes=True)
self.assertAllClose(X[1:], np.zeros((T, 2)), check_dtypes=True)
self.assertAllClose(U[1:], np.zeros((T - 1, 2)), check_dtypes=True)
self.assertAllClose(X[1:], jnp.zeros((T, 2)), check_dtypes=True)
self.assertAllClose(U[1:], jnp.zeros((T - 1, 2)), check_dtypes=True)
def testIlqrWithLqrProblemSpecifiedGenerally(self):
randn = onp.random.RandomState(0).randn
randn = np.random.RandomState(0).randn
dim, T, num_iters = 2, 10, 3
p = one_step_control(dim, T)
x0 = randn(dim)
X, U = control.ilqr(num_iters, p, x0, np.zeros((T, dim)))
X, U = control.ilqr(num_iters, p, x0, jnp.zeros((T, dim)))
self.assertAllClose(X[0], x0, check_dtypes=True)
self.assertAllClose(U[0], -x0, check_dtypes=True)
self.assertAllClose(X[1:], np.zeros((T, 2)), check_dtypes=True)
self.assertAllClose(U[1:], np.zeros((T - 1, 2)), check_dtypes=True)
self.assertAllClose(X[1:], jnp.zeros((T, 2)), check_dtypes=True)
self.assertAllClose(U[1:], jnp.zeros((T - 1, 2)), check_dtypes=True)
def testIlqrWithNonlinearProblem(self):
@@ -188,40 +188,40 @@ class ControlExampleTest(jtu.JaxTestCase):
T, num_iters, d = 10, 7, 1
p = control.ControlSpec(cost, dynamics, T, d, d)
x0 = np.array([0.2])
X, U = control.ilqr(num_iters, p, x0, 1e-5 * np.ones((T, d)))
x0 = jnp.array([0.2])
X, U = control.ilqr(num_iters, p, x0, 1e-5 * jnp.ones((T, d)))
assert_close = partial(self.assertAllClose, atol=1e-2, check_dtypes=True)
assert_close(X[0], x0)
assert_close(U[0] ** 2., x0 ** 2.)
assert_close(X[1:], np.zeros((T, d)))
assert_close(U[1:], np.zeros((T - 1, d)))
assert_close(X[1:], jnp.zeros((T, d)))
assert_close(U[1:], jnp.zeros((T - 1, d)))
def testMpcWithLqrProblem(self):
randn = onp.random.RandomState(0).randn
randn = np.random.RandomState(0).randn
dim, T, num_iters = 2, 10, 3
lqr = one_step_lqr(dim, T)
p = control_from_lqr(lqr)
x0 = randn(dim)
solver = partial(control.ilqr, num_iters)
X, U = control.mpc_predict(solver, p, x0, np.zeros((T, dim)))
X, U = control.mpc_predict(solver, p, x0, jnp.zeros((T, dim)))
self.assertAllClose(X[0], x0, check_dtypes=True)
self.assertAllClose(U[0], -x0, check_dtypes=True)
self.assertAllClose(X[1:], np.zeros((T, 2)), check_dtypes=True)
self.assertAllClose(U[1:], np.zeros((T - 1, 2)), check_dtypes=True)
self.assertAllClose(X[1:], jnp.zeros((T, 2)), check_dtypes=True)
self.assertAllClose(U[1:], jnp.zeros((T - 1, 2)), check_dtypes=True)
def testMpcWithLqrProblemSpecifiedGenerally(self):
randn = onp.random.RandomState(0).randn
randn = np.random.RandomState(0).randn
dim, T, num_iters = 2, 10, 3
p = one_step_control(dim, T)
x0 = randn(dim)
solver = partial(control.ilqr, num_iters)
X, U = control.mpc_predict(solver, p, x0, np.zeros((T, dim)))
X, U = control.mpc_predict(solver, p, x0, jnp.zeros((T, dim)))
self.assertAllClose(X[0], x0, check_dtypes=True)
self.assertAllClose(U[0], -x0, check_dtypes=True)
self.assertAllClose(X[1:], np.zeros((T, 2)), check_dtypes=True)
self.assertAllClose(U[1:], np.zeros((T - 1, 2)), check_dtypes=True)
self.assertAllClose(X[1:], jnp.zeros((T, 2)), check_dtypes=True)
self.assertAllClose(U[1:], jnp.zeros((T - 1, 2)), check_dtypes=True)
def testMpcWithNonlinearProblem(self):
@@ -234,14 +234,14 @@ class ControlExampleTest(jtu.JaxTestCase):
T, num_iters, d = 10, 7, 1
p = control.ControlSpec(cost, dynamics, T, d, d)
x0 = np.array([0.2])
x0 = jnp.array([0.2])
solver = partial(control.ilqr, num_iters)
X, U = control.mpc_predict(solver, p, x0, 1e-5 * np.ones((T, d)))
X, U = control.mpc_predict(solver, p, x0, 1e-5 * jnp.ones((T, d)))
assert_close = partial(self.assertAllClose, atol=1e-2, check_dtypes=True)
assert_close(X[0], x0)
assert_close(U[0] ** 2., x0 ** 2.)
assert_close(X[1:], np.zeros((T, d)))
assert_close(U[1:], np.zeros((T - 1, d)))
assert_close(X[1:], jnp.zeros((T, d)))
assert_close(U[1:], jnp.zeros((T - 1, d)))
if __name__ == '__main__':

Changed file 4 of 12

@@ -80,7 +80,7 @@ from jax import vmap
from jax.experimental import optimizers
from jax.experimental import stax
from jax.lax import stop_gradient
import jax.numpy as np
import jax.numpy as jnp
from examples import datasets
import numpy.random as npr
@@ -124,14 +124,14 @@ def loss(params, batch):
inputs, targets = batch
logits = predict(params, inputs)
logits = stax.logsoftmax(logits) # log normalize
return -np.mean(np.sum(logits * targets, axis=1)) # cross entropy loss
return -jnp.mean(jnp.sum(logits * targets, axis=1)) # cross entropy loss
def accuracy(params, batch):
inputs, targets = batch
target_class = np.argmax(targets, axis=1)
predicted_class = np.argmax(predict(params, inputs), axis=1)
return np.mean(predicted_class == target_class)
target_class = jnp.argmax(targets, axis=1)
predicted_class = jnp.argmax(predict(params, inputs), axis=1)
return jnp.mean(predicted_class == target_class)
def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier,
@@ -143,9 +143,9 @@ def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier,
grads = grad(loss)(params, single_example_batch)
nonempty_grads, tree_def = tree_util.tree_flatten(grads)
total_grad_norm = np.linalg.norm(
[np.linalg.norm(neg.ravel()) for neg in nonempty_grads])
divisor = stop_gradient(np.amax((total_grad_norm / l2_norm_clip, 1.)))
total_grad_norm = jnp.linalg.norm(
[jnp.linalg.norm(neg.ravel()) for neg in nonempty_grads])
divisor = stop_gradient(jnp.amax((total_grad_norm / l2_norm_clip, 1.)))
normalized_nonempty_grads = [g / divisor for g in nonempty_grads]
return tree_util.tree_unflatten(tree_def, normalized_nonempty_grads)
@@ -154,7 +154,7 @@ def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier,
noise_ = lambda n: n + std_dev * random.normal(rng, n.shape)
normalize_ = lambda n: n / float(batch_size)
tree_map = tree_util.tree_map
sum_ = lambda n: np.sum(n, 0) # aggregate
sum_ = lambda n: jnp.sum(n, 0) # aggregate
aggregated_clipped_grads = tree_map(sum_, px_clipped_grad_fn(batch))
noised_aggregated_clipped_grads = tree_map(noise_, aggregated_clipped_grads)
normalized_noised_aggregated_clipped_grads = (
@@ -165,14 +165,14 @@ def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier,
def shape_as_image(images, labels, dummy_dim=False):
target_shape = (-1, 1, 28, 28, 1) if dummy_dim else (-1, 28, 28, 1)
return np.reshape(images, target_shape), labels
return jnp.reshape(images, target_shape), labels
def compute_epsilon(steps, num_examples=60000, target_delta=1e-5):
if num_examples * target_delta > 1.:
warnings.warn('Your delta might be too high.')
q = FLAGS.batch_size / float(num_examples)
orders = list(np.linspace(1.1, 10.9, 99)) + range(11, 64)
orders = list(jnp.linspace(1.1, 10.9, 99)) + range(11, 64)
rdp_const = compute_rdp(q, FLAGS.noise_multiplier, steps, orders)
eps, _, _ = get_privacy_spent(orders, rdp_const, target_delta=target_delta)
return eps

Changed file 5 of 12

@@ -19,11 +19,11 @@ import sys
from absl.testing import absltest
from absl.testing import parameterized
import numpy as onp
import numpy as np
from jax import test_util as jtu
from jax import random
import jax.numpy as np
import jax.numpy as jnp
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from examples import kernel_lsq
@@ -38,7 +38,7 @@ FLAGS = config.FLAGS
def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape):
jax_rng = random.PRNGKey(0)
result_shape, params = init_fun(jax_rng, input_shape)
rng = onp.random.RandomState(0)
rng = np.random.RandomState(0)
result = apply_fun(params, rng.randn(*input_shape).astype(dtype="float32"))
test_case.assertEqual(result.shape, result_shape)
@@ -76,21 +76,21 @@ class ExamplesTest(jtu.JaxTestCase):
def testKernelRegressionGram(self):
n, d = 100, 20
rng = onp.random.RandomState(0)
rng = np.random.RandomState(0)
truth = rng.randn(d)
xs = rng.randn(n, d)
ys = np.dot(xs, truth)
kernel = lambda x, y: np.dot(x, y)
self.assertAllClose(kernel_lsq.gram(kernel, xs), np.dot(xs, xs.T),
ys = jnp.dot(xs, truth)
kernel = lambda x, y: jnp.dot(x, y)
self.assertAllClose(kernel_lsq.gram(kernel, xs), jnp.dot(xs, xs.T),
check_dtypes=False)
def testKernelRegressionTrainAndPredict(self):
n, d = 100, 20
rng = onp.random.RandomState(0)
rng = np.random.RandomState(0)
truth = rng.randn(d)
xs = rng.randn(n, d)
ys = np.dot(xs, truth)
kernel = lambda x, y: np.dot(x, y)
ys = jnp.dot(xs, truth)
kernel = lambda x, y: jnp.dot(x, y)
predict = kernel_lsq.train(kernel, xs, ys)
self.assertAllClose(predict(xs), ys, atol=1e-3, rtol=1e-3,
check_dtypes=False)

Changed file 6 of 12

@@ -22,7 +22,7 @@ from jax import grad
from jax import jit
from jax import vmap
from jax.config import config
import jax.numpy as np
import jax.numpy as jnp
import jax.random as random
import jax.scipy as scipy
import matplotlib.pyplot as plt
@@ -34,7 +34,7 @@ def main(unused_argv):
numpts = 7
key = random.PRNGKey(0)
eye = np.eye(numpts)
eye = jnp.eye(numpts)
def cov_map(cov_func, xs, xs2=None):
"""Compute a covariance matrix from a covariance function and data points.
@@ -51,19 +51,19 @@ def main(unused_argv):
return vmap(lambda x: vmap(lambda y: cov_func(x, y))(xs))(xs2).T
def softplus(x):
return np.logaddexp(x, 0.)
return jnp.logaddexp(x, 0.)
# Note, writing out the vectorized form of the identity
# ||x-y||^2 = <x-y,x-y> = ||x||^2 + ||y||^2 - 2<x,y>
# for computing squared distances would be more efficient (but less succinct).
def exp_quadratic(x1, x2):
return np.exp(-np.sum((x1 - x2)**2))
return jnp.exp(-jnp.sum((x1 - x2)**2))
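The note above about the vectorized identity can be made concrete with a short sketch (an editorial illustration under assumed shapes xs: (n, d), xs2: (m, d); it is not part of this commit):

    import jax.numpy as jnp

    def pairwise_sq_dists(xs, xs2):
      # ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>, computed for all pairs at once.
      xn = jnp.sum(xs ** 2, axis=1)[:, None]    # shape (n, 1)
      yn = jnp.sum(xs2 ** 2, axis=1)[None, :]   # shape (1, m)
      return xn + yn - 2. * jnp.dot(xs, xs2.T)  # shape (n, m)

    def exp_quadratic_gram(xs, xs2):
      # Same values as cov_map(exp_quadratic, xs, xs2) for this symmetric kernel,
      # but without the nested vmap.
      return jnp.exp(-pairwise_sq_dists(xs, xs2))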
def gp(params, x, y, xtest=None, compute_marginal_likelihood=False):
noise = softplus(params['noise'])
amp = softplus(params['amplitude'])
ls = softplus(params['lengthscale'])
ymean = np.mean(y)
ymean = jnp.mean(y)
y = y - ymean
x = x / ls
train_cov = amp*cov_map(exp_quadratic, x) + eye * (noise + 1e-6)
@@ -71,20 +71,20 @@ def main(unused_argv):
kinvy = scipy.linalg.solve_triangular(
chol.T, scipy.linalg.solve_triangular(chol, y, lower=True))
if compute_marginal_likelihood:
log2pi = np.log(2. * 3.1415)
ml = np.sum(
-0.5 * np.dot(y.T, kinvy) -
np.sum(np.log(np.diag(chol))) -
log2pi = jnp.log(2. * 3.1415)
ml = jnp.sum(
-0.5 * jnp.dot(y.T, kinvy) -
jnp.sum(jnp.log(jnp.diag(chol))) -
(numpts / 2.) * log2pi)
ml -= np.sum(-0.5 * np.log(2 * 3.1415) - np.log(amp)**2) # lognormal prior
ml -= jnp.sum(-0.5 * jnp.log(2 * 3.1415) - jnp.log(amp)**2) # lognormal prior
return -ml
if xtest is not None:
xtest = xtest / ls
cross_cov = amp*cov_map(exp_quadratic, x, xtest)
mu = np.dot(cross_cov.T, kinvy) + ymean
mu = jnp.dot(cross_cov.T, kinvy) + ymean
v = scipy.linalg.solve_triangular(chol, cross_cov, lower=True)
var = (amp * cov_map(exp_quadratic, xtest) - np.dot(v.T, v))
var = (amp * cov_map(exp_quadratic, xtest) - jnp.dot(v.T, v))
return mu, var
marginal_likelihood = partial(gp, compute_marginal_likelihood=True)
@@ -92,9 +92,9 @@ def main(unused_argv):
grad_fun = jit(grad(marginal_likelihood))
# Covariance hyperparameters to be learned
params = {"amplitude": np.zeros((1, 1)),
"noise": np.zeros((1, 1)) - 5.,
"lengthscale": np.zeros((1, 1))}
params = {"amplitude": jnp.zeros((1, 1)),
"noise": jnp.zeros((1, 1)) - 5.,
"lengthscale": jnp.zeros((1, 1))}
momentums = dict([(k, p * 0.) for k, p in params.items()])
scales = dict([(k, p * 0. + 1.) for k, p in params.items()])
@@ -104,14 +104,14 @@ def main(unused_argv):
for k in (params):
momentums[k] = 0.9 * momentums[k] + 0.1 * grads[k][0]
scales[k] = 0.9 * scales[k] + 0.1 * grads[k][0]**2
params[k] -= lr * momentums[k]/np.sqrt(scales[k] + 1e-5)
params[k] -= lr * momentums[k]/jnp.sqrt(scales[k] + 1e-5)
return params, momentums, scales
# Create a really simple toy 1D function
y_fun = lambda x: np.sin(x) + 0.1 * random.normal(key, shape=(x.shape[0], 1))
y_fun = lambda x: jnp.sin(x) + 0.1 * random.normal(key, shape=(x.shape[0], 1))
x = (random.uniform(key, shape=(numpts, 1)) * 4.) + 1
y = y_fun(x)
xtest = np.linspace(0, 6., 200)[:, None]
xtest = jnp.linspace(0, 6., 200)[:, None]
ytest = y_fun(xtest)
for i in range(1000):
@@ -122,7 +122,7 @@ def main(unused_argv):
print(params)
mu, var = predict(params, x, y, xtest)
std = np.sqrt(np.diag(var))
std = jnp.sqrt(jnp.diag(var))
plt.plot(x, y, "k.")
plt.plot(xtest, mu)
plt.fill_between(xtest.flatten(),

Changed file 7 of 12

@@ -17,7 +17,7 @@ from functools import partial
import numpy.random as npr
import jax.numpy as np
import jax.numpy as jnp
from jax.config import config
from jax.experimental import optimizers
from jax import grad, jit, make_jaxpr, vmap
@@ -56,15 +56,15 @@ def train(kernel, xs, ys, regularization=0.01):
n = xs.shape[0]
def objective(v):
risk = .5 * np.sum((np.dot(gram_mat, v) - ys) ** 2.0)
reg = regularization * np.sum(v ** 2.0)
risk = .5 * jnp.sum((jnp.dot(gram_mat, v) - ys) ** 2.0)
reg = regularization * jnp.sum(v ** 2.0)
return risk + reg
v = minimize(objective, np.zeros(n))
v = minimize(objective, jnp.zeros(n))
def predict(x):
prods = vmap(lambda x_: kernel(x, x_))(xs)
return np.sum(v * prods)
return jnp.sum(v * prods)
return jit(vmap(predict))
@@ -75,19 +75,19 @@ if __name__ == "__main__":
# linear kernel
linear_kernel = lambda x, y: np.dot(x, y)
linear_kernel = lambda x, y: jnp.dot(x, y)
truth = npr.randn(d)
xs = npr.randn(n, d)
ys = np.dot(xs, truth)
ys = jnp.dot(xs, truth)
predict = train(linear_kernel, xs, ys)
print('MSE:', np.sum((predict(xs) - ys) ** 2.))
print('MSE:', jnp.sum((predict(xs) - ys) ** 2.))
def gram_jaxpr(kernel):
return make_jaxpr(partial(gram, kernel))(xs)
rbf_kernel = lambda x, y: np.exp(-np.sum((x - y) ** 2))
rbf_kernel = lambda x, y: jnp.exp(-jnp.sum((x - y) ** 2))
print()
print('jaxpr of gram(linear_kernel):')

Changed file 8 of 12

@@ -25,7 +25,7 @@ import itertools
import numpy.random as npr
import jax.numpy as np
import jax.numpy as jnp
from jax.config import config
from jax import jit, grad, random
from jax.experimental import optimizers
@@ -37,13 +37,13 @@ from examples import datasets
def loss(params, batch):
inputs, targets = batch
preds = predict(params, inputs)
return -np.mean(np.sum(preds * targets, axis=1))
return -jnp.mean(jnp.sum(preds * targets, axis=1))
def accuracy(params, batch):
inputs, targets = batch
target_class = np.argmax(targets, axis=1)
predicted_class = np.argmax(predict(params, inputs), axis=1)
return np.mean(predicted_class == target_class)
target_class = jnp.argmax(targets, axis=1)
predicted_class = jnp.argmax(predict(params, inputs), axis=1)
return jnp.mean(predicted_class == target_class)
init_random_params, predict = stax.serial(
Dense(1024), Relu,

Changed file 9 of 12

@@ -25,7 +25,7 @@ import numpy.random as npr
from jax.api import jit, grad
from jax.config import config
from jax.scipy.special import logsumexp
import jax.numpy as np
import jax.numpy as jnp
from examples import datasets
@@ -36,23 +36,23 @@ def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
def predict(params, inputs):
activations = inputs
for w, b in params[:-1]:
outputs = np.dot(activations, w) + b
activations = np.tanh(outputs)
outputs = jnp.dot(activations, w) + b
activations = jnp.tanh(outputs)
final_w, final_b = params[-1]
logits = np.dot(activations, final_w) + final_b
logits = jnp.dot(activations, final_w) + final_b
return logits - logsumexp(logits, axis=1, keepdims=True)
def loss(params, batch):
inputs, targets = batch
preds = predict(params, inputs)
return -np.mean(np.sum(preds * targets, axis=1))
return -jnp.mean(jnp.sum(preds * targets, axis=1))
def accuracy(params, batch):
inputs, targets = batch
target_class = np.argmax(targets, axis=1)
predicted_class = np.argmax(predict(params, inputs), axis=1)
return np.mean(predicted_class == target_class)
target_class = jnp.argmax(targets, axis=1)
predicted_class = jnp.argmax(predict(params, inputs), axis=1)
return jnp.mean(predicted_class == target_class)
if __name__ == "__main__":

Changed file 10 of 12

@@ -24,7 +24,7 @@ import time
import matplotlib.pyplot as plt
import jax.numpy as np
import jax.numpy as jnp
from jax.config import config
from jax import jit, grad, lax, random
from jax.experimental import optimizers
@@ -35,15 +35,15 @@ from examples import datasets
def gaussian_kl(mu, sigmasq):
"""KL divergence from a diagonal Gaussian to the standard Gaussian."""
return -0.5 * np.sum(1. + np.log(sigmasq) - mu**2. - sigmasq)
return -0.5 * jnp.sum(1. + jnp.log(sigmasq) - mu**2. - sigmasq)
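The expression above is the standard closed form for the KL divergence from N(mu, diag(sigmasq)) to N(0, I), i.e. 0.5 * (mu**2 + sigmasq - log(sigmasq) - 1) summed over dimensions. A tiny numerical check (editorial illustration, not part of this commit):

    import jax.numpy as jnp

    mu, sigmasq = jnp.array([0.5]), jnp.array([2.0])
    closed_form = jnp.sum(0.5 * (mu ** 2 + sigmasq - jnp.log(sigmasq) - 1.))
    as_above = -0.5 * jnp.sum(1. + jnp.log(sigmasq) - mu ** 2. - sigmasq)
    # both evaluate to roughly 0.278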
def gaussian_sample(rng, mu, sigmasq):
"""Sample a diagonal Gaussian."""
return mu + np.sqrt(sigmasq) * random.normal(rng, mu.shape)
return mu + jnp.sqrt(sigmasq) * random.normal(rng, mu.shape)
def bernoulli_logpdf(logits, x):
"""Bernoulli log pdf of data x given logits."""
return -np.sum(np.logaddexp(0., np.where(x, -1., 1.) * logits))
return -jnp.sum(jnp.logaddexp(0., jnp.where(x, -1., 1.) * logits))
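The logaddexp form above is a numerically stable way to write the Bernoulli log pmf in terms of logits: with sigmoid(l) = 1 / (1 + exp(-l)), log p(x=1 | l) = -logaddexp(0, -l) and log p(x=0 | l) = -logaddexp(0, l), and where(x, -1., 1.) selects between the two. A direct but less stable reference form for comparison (editorial sketch, not part of this commit):

    import jax.numpy as jnp

    def bernoulli_logpdf_reference(logits, x):
      p = 1. / (1. + jnp.exp(-logits))  # sigmoid; loses precision for large |logits|
      return jnp.sum(jnp.where(x, jnp.log(p), jnp.log1p(-p)))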
def elbo(rng, params, images):
"""Monte Carlo estimate of the negative evidence lower bound."""
@@ -57,13 +57,13 @@ def image_sample(rng, params, nrow, ncol):
_, dec_params = params
code_rng, img_rng = random.split(rng)
logits = decode(dec_params, random.normal(code_rng, (nrow * ncol, 10)))
sampled_images = random.bernoulli(img_rng, np.logaddexp(0., logits))
sampled_images = random.bernoulli(img_rng, jnp.logaddexp(0., logits))
return image_grid(nrow, ncol, sampled_images, (28, 28))
def image_grid(nrow, ncol, imagevecs, imshape):
"""Reshape a stack of image vectors into an image grid for plotting."""
images = iter(imagevecs.reshape((-1,) + imshape))
return np.vstack([np.hstack([next(images).T for _ in range(ncol)][::-1])
return jnp.vstack([jnp.hstack([next(images).T for _ in range(ncol)][::-1])
for _ in range(nrow)]).T

Changed file 11 of 12

@@ -20,7 +20,7 @@ optimization library.
import numpy.random as npr
import jax.numpy as np
import jax.numpy as jnp
from jax.config import config
from jax import jit, grad, random
from jax.experimental import optimizers
@@ -96,20 +96,20 @@ if __name__ == "__main__":
def loss(params, batch):
inputs, targets = batch
logits = predict_fun(params, inputs)
return -np.sum(logits * targets)
return -jnp.sum(logits * targets)
def accuracy(params, batch):
inputs, targets = batch
target_class = np.argmax(targets, axis=-1)
predicted_class = np.argmax(predict_fun(params, inputs), axis=-1)
return np.mean(predicted_class == target_class)
target_class = jnp.argmax(targets, axis=-1)
predicted_class = jnp.argmax(predict_fun(params, inputs), axis=-1)
return jnp.mean(predicted_class == target_class)
def synth_batches():
rng = npr.RandomState(0)
while True:
images = rng.rand(*input_shape).astype('float32')
labels = rng.randint(num_classes, size=(batch_size, 1))
onehot_labels = labels == np.arange(num_classes)
onehot_labels = labels == jnp.arange(num_classes)
yield images, onehot_labels
opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=0.9)

Changed file 12 of 12

@@ -24,7 +24,7 @@ optimizers libraries.
from functools import partial
import time
import numpy as onp
import numpy as np
import numpy.random as npr
from jax import jit, grad, pmap
@@ -33,7 +33,7 @@ from jax.scipy.special import logsumexp
from jax.lib import xla_bridge
from jax.tree_util import tree_map
from jax import lax
import jax.numpy as np
import jax.numpy as jnp
from examples import datasets
@@ -44,24 +44,24 @@ def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
def predict(params, inputs):
activations = inputs
for w, b in params[:-1]:
outputs = np.dot(activations, w) + b
activations = np.tanh(outputs)
outputs = jnp.dot(activations, w) + b
activations = jnp.tanh(outputs)
final_w, final_b = params[-1]
logits = np.dot(activations, final_w) + final_b
logits = jnp.dot(activations, final_w) + final_b
return logits - logsumexp(logits, axis=1, keepdims=True)
def loss(params, batch):
inputs, targets = batch
preds = predict(params, inputs)
return -np.mean(np.sum(preds * targets, axis=1))
return -jnp.mean(jnp.sum(preds * targets, axis=1))
@jit
def accuracy(params, batch):
inputs, targets = batch
target_class = np.argmax(targets, axis=1)
predicted_class = np.argmax(predict(params, inputs), axis=1)
return np.mean(predicted_class == target_class)
target_class = jnp.argmax(targets, axis=1)
predicted_class = jnp.argmax(predict(params, inputs), axis=1)
return jnp.mean(predicted_class == target_class)
if __name__ == "__main__":
@@ -110,7 +110,7 @@ if __name__ == "__main__":
# We replicate the parameters so that the constituent arrays have a leading
# dimension of size equal to the number of devices we're pmapping over.
init_params = init_random_params(param_scale, layer_sizes)
replicate_array = lambda x: onp.broadcast_to(x, (num_devices,) + x.shape)
replicate_array = lambda x: np.broadcast_to(x, (num_devices,) + x.shape)
replicated_params = tree_map(replicate_array, init_params)
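The replication comment above is the heart of the pmap setup: every parameter gains a leading axis of length num_devices, and pmap maps over that axis. A minimal shape check (the device count and weight shape are illustrative assumptions, not taken from the example):

    import numpy as np

    num_devices = 8
    w = np.zeros((784, 1024))                             # one weight matrix
    w_rep = np.broadcast_to(w, (num_devices,) + w.shape)  # add a leading device axis
    assert w_rep.shape == (8, 784, 1024)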
for epoch in range(num_epochs):