Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)
Replace np -> jnp, onp -> np in examples/ (#2971)
For context, see #2370
This commit is contained in:
parent
b543652332
commit
d59ecddfe8
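For reference, the aliasing convention the diff below applies throughout examples/: plain NumPy keeps the np name and JAX's NumPy-compatible API gets jnp, so host-side array code and traceable JAX code stay visually distinct. A minimal illustrative sketch of the convention (not taken from any one file in this commit):

import numpy as np        # classic NumPy: host-side data, RNG state, test fixtures
import jax.numpy as jnp   # JAX's NumPy API: code meant to run under jit/grad/vmap

x = np.arange(4.0)          # concrete NumPy array, built eagerly on the host
y = jnp.mean(jnp.exp(x))    # JAX ops accept it and return a device array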
@@ -25,7 +25,7 @@ import matplotlib.pyplot as plt
 from jax.api import jit, grad, vmap
 from jax import random
 from jax.experimental import optimizers
-import jax.numpy as np
+import jax.numpy as jnp
 import jax.scipy.stats.norm as norm
 
 
@@ -33,11 +33,11 @@ import jax.scipy.stats.norm as norm
 
 def diag_gaussian_sample(rng, mean, log_std):
   # Take a single sample from a diagonal multivariate Gaussian.
-  return mean + np.exp(log_std) * random.normal(rng, mean.shape)
+  return mean + jnp.exp(log_std) * random.normal(rng, mean.shape)
 
 def diag_gaussian_logpdf(x, mean, log_std):
   # Evaluate a single point on a diagonal multivariate Gaussian.
-  return np.sum(vmap(norm.logpdf)(x, mean, np.exp(log_std)))
+  return jnp.sum(vmap(norm.logpdf)(x, mean, jnp.exp(log_std)))
 
 def elbo(logprob, rng, mean, log_std):
   # Single-sample Monte Carlo estimate of the variational lower bound.
@@ -48,7 +48,7 @@ def batch_elbo(logprob, rng, params, num_samples):
   # Average over a batch of random samples.
   rngs = random.split(rng, num_samples)
   vectorized_elbo = vmap(partial(elbo, logprob), in_axes=(0, None, None))
-  return np.mean(vectorized_elbo(rngs, *params))
+  return jnp.mean(vectorized_elbo(rngs, *params))
 
 
 # ========= Helper function for plotting. =========
@@ -56,10 +56,10 @@ def batch_elbo(logprob, rng, params, num_samples):
 @partial(jit, static_argnums=(0, 1, 2, 4))
 def _mesh_eval(func, x_limits, y_limits, params, num_ticks):
   # Evaluate func on a 2D grid defined by x_limits and y_limits.
-  x = np.linspace(*x_limits, num=num_ticks)
-  y = np.linspace(*y_limits, num=num_ticks)
-  X, Y = np.meshgrid(x, y)
-  xy_vec = np.stack([X.ravel(), Y.ravel()]).T
+  x = jnp.linspace(*x_limits, num=num_ticks)
+  y = jnp.linspace(*y_limits, num=num_ticks)
+  X, Y = jnp.meshgrid(x, y)
+  xy_vec = jnp.stack([X.ravel(), Y.ravel()]).T
   zs = vmap(func, in_axes=(0, None))(xy_vec, params)
   return X, Y, zs.reshape(X.shape)
 
@@ -69,7 +69,7 @@ def mesh_eval(func, x_limits, y_limits, params, num_ticks=101):
 # ========= Define an intractable unnormalized density =========
 
 def funnel_log_density(params):
-  return norm.logpdf(params[0], 0, np.exp(params[1])) + \
+  return norm.logpdf(params[0], 0, jnp.exp(params[1])) + \
          norm.logpdf(params[1], 0, 1.35)
 
 
@@ -88,8 +88,8 @@ if __name__ == "__main__":
   plt.show(block=False)
   x_limits = [-2, 2]
   y_limits = [-4, 2]
-  target_dist = lambda x, _: np.exp(funnel_log_density(x))
-  approx_dist = lambda x, params: np.exp(diag_gaussian_logpdf(x, *params))
+  target_dist = lambda x, _: jnp.exp(funnel_log_density(x))
+  approx_dist = lambda x, params: jnp.exp(diag_gaussian_logpdf(x, *params))
 
   def callback(params, t):
     print("Iteration {} lower bound {}".format(t, objective(params, t)))
@@ -117,8 +117,8 @@ if __name__ == "__main__":
 
   # Set up optimizer.
   D = 2
-  init_mean = np.zeros(D)
-  init_std = np.zeros(D)
+  init_mean = jnp.zeros(D)
+  init_std = jnp.zeros(D)
   init_params = (init_mean, init_std)
   opt_init, opt_update, get_params = optimizers.momentum(step_size=0.1, mass=0.9)
   opt_state = opt_init(init_params)

@@ -18,7 +18,7 @@ Model-predictive non-linear control example.
 import collections
 
 from jax import lax, grad, jacfwd, jacobian, vmap
-import jax.numpy as np
+import jax.numpy as jnp
 import jax.ops as jo
 
 
@@ -52,8 +52,8 @@ ControlSpec = collections.namedtuple(
 LqrSpec = collections.namedtuple('LqrSpec', 'Q q R r M A B')
 
 
-dot = np.dot
-mm = np.matmul
+dot = jnp.dot
+mm = jnp.matmul
 
 
 def mv(mat, vec):
@@ -75,7 +75,7 @@ def fori_loop(lo, hi, loop, init):
 def scan_fori_loop(lo, hi, loop, init):
   def scan_f(x, t):
     return loop(t, x), ()
-  x, _ = lax.scan(scan_f, init, np.arange(lo, hi))
+  x, _ = lax.scan(scan_f, init, jnp.arange(lo, hi))
   return x
 
 
@@ -84,7 +84,7 @@ def trajectory(dynamics, U, x0):
   T, _ = U.shape
   d, = x0.shape
 
-  X = np.zeros((T + 1, d))
+  X = jnp.zeros((T + 1, d))
   X = jo.index_update(X, jo.index[0], x0)
 
   def loop(t, X):
@@ -110,8 +110,8 @@ def make_lqr_approx(p):
 
   def approx(X, U):
     assert X.shape[0] == T + 1 and U.shape[0] == T
-    U_pad = np.vstack((U, np.zeros((1,) + U.shape[1:])))
-    Q, q, R, r, M, A, B = _approx(np.arange(T + 1), X, U_pad)
+    U_pad = jnp.vstack((U, jnp.zeros((1,) + U.shape[1:])))
+    Q, q, R, r, M, A, B = _approx(jnp.arange(T + 1), X, U_pad)
     return LqrSpec(Q, q, R[:T], r[:T], M[:T], A[:T], B[:T])
 
   return approx
@@ -122,8 +122,8 @@ def lqr_solve(spec):
   T, control_dim, _ = spec.R.shape
   _, state_dim, _ = spec.Q.shape
 
-  K = np.zeros((T, control_dim, state_dim))
-  k = np.zeros((T, control_dim))
+  K = jnp.zeros((T, control_dim, state_dim))
+  k = jnp.zeros((T, control_dim))
 
   def rev_loop(t_, state):
     t = T - t_ - 1
@@ -139,8 +139,8 @@ def lqr_solve(spec):
     G = R + mm(BtP, B)
     H = mm(BtP, A) + M.T
     h = r + mv(B.T, p)
-    K_ = -np.linalg.solve(G + EPS * np.eye(G.shape[0]), H)
-    k_ = -np.linalg.solve(G + EPS * np.eye(G.shape[0]), h)
+    K_ = -jnp.linalg.solve(G + EPS * jnp.eye(G.shape[0]), H)
+    k_ = -jnp.linalg.solve(G + EPS * jnp.eye(G.shape[0]), h)
     P_ = Q + mm(AtP, A) + mm(K_.T, H)
     p_ = q + mv(A.T, p) + mv(K_.T, h)
 
@@ -170,8 +170,8 @@ def lqr_predict(spec, x0):
     U = jo.index_update(U, jo.index[t], u)
     return spec, X, U
 
-  U = np.zeros((T, control_dim))
-  X = np.zeros((T + 1, state_dim))
+  U = jnp.zeros((T, control_dim))
+  X = jnp.zeros((T + 1, state_dim))
   X = jo.index_update(X, jo.index[0], x0)
   _, X, U = fori_loop(0, T, fwd_loop, (spec, X, U))
   return X, U
@@ -186,7 +186,7 @@ def ilqr(iterations, p, x0, U):
   def loop(_, state):
     X, U = state
     p_lqr = lqr_approx(X, U)
-    dX, dU = lqr_predict(p_lqr, np.zeros_like(x0))
+    dX, dU = lqr_predict(p_lqr, jnp.zeros_like(x0))
     U = U + dU
     X = trajectory(p.dynamics, U, X[0] + dX[0])
     return X, U
@@ -200,7 +200,7 @@ def mpc_predict(solver, p, x0, U):
   T = p.horizon
 
   def zero_padded_controls_window(U, t):
-    U_pad = np.vstack((U, np.zeros(U.shape)))
+    U_pad = jnp.vstack((U, jnp.zeros(U.shape)))
     return lax.dynamic_slice_in_dim(U_pad, t, T, axis=0)
 
   def loop(t, state):
@@ -218,6 +218,6 @@ def mpc_predict(solver, p, x0, U):
     U = jo.index_update(U, jo.index[t], ut)
     return X, U
 
-  X = np.zeros((T + 1, p.state_dim))
+  X = jnp.zeros((T + 1, p.state_dim))
   X = jo.index_update(X, jo.index[0], x0)
   return fori_loop(0, T, loop, (X, U))

@@ -16,11 +16,11 @@ from functools import partial
 from unittest import SkipTest
 
 from absl.testing import absltest
-import numpy as onp
+import numpy as np
 
 from jax import lax
 from jax import test_util as jtu
-import jax.numpy as np
+import jax.numpy as jnp
 
 from examples import control
 
@@ -30,19 +30,19 @@ FLAGS = config.FLAGS
 
 
 def one_step_lqr(dim, T):
-  Q = np.stack(T * (np.eye(dim),))
-  q = np.zeros((T, dim))
-  R = np.zeros((T, dim, dim))
-  r = np.zeros((T, dim))
-  M = np.zeros((T, dim, dim))
-  A = np.stack(T * (np.eye(dim),))
-  B = np.stack(T * (np.eye(dim),))
+  Q = jnp.stack(T * (jnp.eye(dim),))
+  q = jnp.zeros((T, dim))
+  R = jnp.zeros((T, dim, dim))
+  r = jnp.zeros((T, dim))
+  M = jnp.zeros((T, dim, dim))
+  A = jnp.stack(T * (jnp.eye(dim),))
+  B = jnp.stack(T * (jnp.eye(dim),))
   return control.LqrSpec(Q, q, R, r, M, A, B)
 
 
 def control_from_lqr(lqr):
   T, dim, _ = lqr.Q.shape
-  dot = np.dot
+  dot = jnp.dot
 
   def cost(t, x, u):
     return (
@@ -59,7 +59,7 @@ def control_from_lqr(lqr):
 def one_step_control(dim, T):
 
   def cost(t, x, u):
-    return np.dot(x, x)
+    return jnp.dot(x, x)
 
   def dynamics(t, x, u):
     return x + u
@@ -77,33 +77,33 @@ class ControlExampleTest(jtu.JaxTestCase):
 
     T = 10
 
-    U = np.ones((T, 1))
-    X = control.trajectory(dynamics, U, np.zeros(1))
-    expected = np.arange(T + 1) % num_states
-    expected = np.reshape(expected, (T + 1, 1))
+    U = jnp.ones((T, 1))
+    X = control.trajectory(dynamics, U, jnp.zeros(1))
+    expected = jnp.arange(T + 1) % num_states
+    expected = jnp.reshape(expected, (T + 1, 1))
     self.assertAllClose(X, expected, check_dtypes=False)
 
-    U = 2 * np.ones((T, 1))
-    X = control.trajectory(dynamics, U, np.zeros(1))
-    expected = np.cumsum(2 * np.ones(T)) % num_states
-    expected = np.concatenate((np.zeros(1), expected))
-    expected = np.reshape(expected, (T + 1, 1))
+    U = 2 * jnp.ones((T, 1))
+    X = control.trajectory(dynamics, U, jnp.zeros(1))
+    expected = jnp.cumsum(2 * jnp.ones(T)) % num_states
+    expected = jnp.concatenate((jnp.zeros(1), expected))
+    expected = jnp.reshape(expected, (T + 1, 1))
     self.assertAllClose(X, expected, check_dtypes=False)
 
   def testTrajectoryTimeVarying(self):
     T = 6
 
    def clip(x, lo, hi):
-      return np.minimum(hi, np.maximum(lo, x))
+      return jnp.minimum(hi, jnp.maximum(lo, x))
 
    def dynamics(t, x, u):
      # computes `(x + u) if t > T else 0`
      return (x + u) * clip(t - T, 0, 1)
 
-    U = np.ones((2 * T, 1))
-    X = control.trajectory(dynamics, U, np.zeros(1))
-    expected = np.concatenate((np.zeros(T + 1), np.arange(T)))
-    expected = np.reshape(expected, (2 * T + 1, 1))
+    U = jnp.ones((2 * T, 1))
+    X = control.trajectory(dynamics, U, jnp.zeros(1))
+    expected = jnp.concatenate((jnp.zeros(T + 1), jnp.arange(T)))
+    expected = jnp.reshape(expected, (2 * T + 1, 1))
     self.assertAllClose(X, expected, check_dtypes=True)
 
 
@@ -112,20 +112,20 @@ class ControlExampleTest(jtu.JaxTestCase):
 
    def position(x):
      '''finds the index of a standard basis vector, e.g. [0, 1, 0] -> 1'''
-      x = np.cumsum(x)
+      x = jnp.cumsum(x)
      x = 1 - x
-      return np.sum(x, dtype=np.int32)
+      return jnp.sum(x, dtype=jnp.int32)
 
    def dynamics(t, x, u):
      '''moves the next standard basis vector'''
      idx = (position(x) + u[0]) % num_states
-      return lax.dynamic_slice_in_dim(np.eye(num_states), idx, 1)[0]
+      return lax.dynamic_slice_in_dim(jnp.eye(num_states), idx, 1)[0]
 
    T = 8
 
-    U = np.ones((T, 1), dtype=np.int32)
-    X = control.trajectory(dynamics, U, np.eye(num_states, dtype=np.int32)[0])
-    expected = np.vstack((np.eye(num_states),) * 3)
+    U = jnp.ones((T, 1), dtype=jnp.int32)
+    X = control.trajectory(dynamics, U, jnp.eye(num_states, dtype=jnp.int32)[0])
+    expected = jnp.vstack((jnp.eye(num_states),) * 3)
    self.assertAllClose(X, expected, check_dtypes=True)
 
 
@@ -133,13 +133,13 @@ class ControlExampleTest(jtu.JaxTestCase):
     dim, T = 2, 10
     p = one_step_lqr(dim, T)
     K, k = control.lqr_solve(p)
-    K_ = -np.stack(T * (np.eye(dim),))
+    K_ = -jnp.stack(T * (jnp.eye(dim),))
     self.assertAllClose(K, K_, check_dtypes=True, atol=1e-6, rtol=1e-6)
-    self.assertAllClose(k, np.zeros((T, dim)), check_dtypes=True)
+    self.assertAllClose(k, jnp.zeros((T, dim)), check_dtypes=True)
 
 
   def testLqrPredict(self):
-    randn = onp.random.RandomState(0).randn
+    randn = np.random.RandomState(0).randn
     dim, T = 2, 10
     p = one_step_lqr(dim, T)
     x0 = randn(dim)
@@ -147,35 +147,35 @@ class ControlExampleTest(jtu.JaxTestCase):
     self.assertAllClose(X[0], x0, check_dtypes=True)
     self.assertAllClose(U[0], -x0, check_dtypes=True,
                         atol=1e-6, rtol=1e-6)
-    self.assertAllClose(X[1:], np.zeros((T, 2)), check_dtypes=True,
+    self.assertAllClose(X[1:], jnp.zeros((T, 2)), check_dtypes=True,
                         atol=1e-6, rtol=1e-6)
-    self.assertAllClose(U[1:], np.zeros((T - 1, 2)), check_dtypes=True,
+    self.assertAllClose(U[1:], jnp.zeros((T - 1, 2)), check_dtypes=True,
                         atol=1e-6, rtol=1e-6)
 
 
   def testIlqrWithLqrProblem(self):
-    randn = onp.random.RandomState(0).randn
+    randn = np.random.RandomState(0).randn
     dim, T, num_iters = 2, 10, 3
     lqr = one_step_lqr(dim, T)
     p = control_from_lqr(lqr)
     x0 = randn(dim)
-    X, U = control.ilqr(num_iters, p, x0, np.zeros((T, dim)))
+    X, U = control.ilqr(num_iters, p, x0, jnp.zeros((T, dim)))
     self.assertAllClose(X[0], x0, check_dtypes=True)
     self.assertAllClose(U[0], -x0, check_dtypes=True)
-    self.assertAllClose(X[1:], np.zeros((T, 2)), check_dtypes=True)
-    self.assertAllClose(U[1:], np.zeros((T - 1, 2)), check_dtypes=True)
+    self.assertAllClose(X[1:], jnp.zeros((T, 2)), check_dtypes=True)
+    self.assertAllClose(U[1:], jnp.zeros((T - 1, 2)), check_dtypes=True)
 
 
   def testIlqrWithLqrProblemSpecifiedGenerally(self):
-    randn = onp.random.RandomState(0).randn
+    randn = np.random.RandomState(0).randn
     dim, T, num_iters = 2, 10, 3
     p = one_step_control(dim, T)
     x0 = randn(dim)
-    X, U = control.ilqr(num_iters, p, x0, np.zeros((T, dim)))
+    X, U = control.ilqr(num_iters, p, x0, jnp.zeros((T, dim)))
     self.assertAllClose(X[0], x0, check_dtypes=True)
     self.assertAllClose(U[0], -x0, check_dtypes=True)
-    self.assertAllClose(X[1:], np.zeros((T, 2)), check_dtypes=True)
-    self.assertAllClose(U[1:], np.zeros((T - 1, 2)), check_dtypes=True)
+    self.assertAllClose(X[1:], jnp.zeros((T, 2)), check_dtypes=True)
+    self.assertAllClose(U[1:], jnp.zeros((T - 1, 2)), check_dtypes=True)
 
 
   def testIlqrWithNonlinearProblem(self):
@@ -188,40 +188,40 @@ class ControlExampleTest(jtu.JaxTestCase):
     T, num_iters, d = 10, 7, 1
     p = control.ControlSpec(cost, dynamics, T, d, d)
 
-    x0 = np.array([0.2])
-    X, U = control.ilqr(num_iters, p, x0, 1e-5 * np.ones((T, d)))
+    x0 = jnp.array([0.2])
+    X, U = control.ilqr(num_iters, p, x0, 1e-5 * jnp.ones((T, d)))
     assert_close = partial(self.assertAllClose, atol=1e-2, check_dtypes=True)
     assert_close(X[0], x0)
     assert_close(U[0] ** 2., x0 ** 2.)
-    assert_close(X[1:], np.zeros((T, d)))
-    assert_close(U[1:], np.zeros((T - 1, d)))
+    assert_close(X[1:], jnp.zeros((T, d)))
+    assert_close(U[1:], jnp.zeros((T - 1, d)))
 
 
   def testMpcWithLqrProblem(self):
-    randn = onp.random.RandomState(0).randn
+    randn = np.random.RandomState(0).randn
     dim, T, num_iters = 2, 10, 3
     lqr = one_step_lqr(dim, T)
     p = control_from_lqr(lqr)
     x0 = randn(dim)
     solver = partial(control.ilqr, num_iters)
-    X, U = control.mpc_predict(solver, p, x0, np.zeros((T, dim)))
+    X, U = control.mpc_predict(solver, p, x0, jnp.zeros((T, dim)))
     self.assertAllClose(X[0], x0, check_dtypes=True)
     self.assertAllClose(U[0], -x0, check_dtypes=True)
-    self.assertAllClose(X[1:], np.zeros((T, 2)), check_dtypes=True)
-    self.assertAllClose(U[1:], np.zeros((T - 1, 2)), check_dtypes=True)
+    self.assertAllClose(X[1:], jnp.zeros((T, 2)), check_dtypes=True)
+    self.assertAllClose(U[1:], jnp.zeros((T - 1, 2)), check_dtypes=True)
 
 
   def testMpcWithLqrProblemSpecifiedGenerally(self):
-    randn = onp.random.RandomState(0).randn
+    randn = np.random.RandomState(0).randn
     dim, T, num_iters = 2, 10, 3
     p = one_step_control(dim, T)
     x0 = randn(dim)
     solver = partial(control.ilqr, num_iters)
-    X, U = control.mpc_predict(solver, p, x0, np.zeros((T, dim)))
+    X, U = control.mpc_predict(solver, p, x0, jnp.zeros((T, dim)))
     self.assertAllClose(X[0], x0, check_dtypes=True)
     self.assertAllClose(U[0], -x0, check_dtypes=True)
-    self.assertAllClose(X[1:], np.zeros((T, 2)), check_dtypes=True)
-    self.assertAllClose(U[1:], np.zeros((T - 1, 2)), check_dtypes=True)
+    self.assertAllClose(X[1:], jnp.zeros((T, 2)), check_dtypes=True)
+    self.assertAllClose(U[1:], jnp.zeros((T - 1, 2)), check_dtypes=True)
 
 
   def testMpcWithNonlinearProblem(self):
@@ -234,14 +234,14 @@ class ControlExampleTest(jtu.JaxTestCase):
     T, num_iters, d = 10, 7, 1
     p = control.ControlSpec(cost, dynamics, T, d, d)
 
-    x0 = np.array([0.2])
+    x0 = jnp.array([0.2])
     solver = partial(control.ilqr, num_iters)
-    X, U = control.mpc_predict(solver, p, x0, 1e-5 * np.ones((T, d)))
+    X, U = control.mpc_predict(solver, p, x0, 1e-5 * jnp.ones((T, d)))
     assert_close = partial(self.assertAllClose, atol=1e-2, check_dtypes=True)
     assert_close(X[0], x0)
     assert_close(U[0] ** 2., x0 ** 2.)
-    assert_close(X[1:], np.zeros((T, d)))
-    assert_close(U[1:], np.zeros((T - 1, d)))
+    assert_close(X[1:], jnp.zeros((T, d)))
+    assert_close(U[1:], jnp.zeros((T - 1, d)))
 
 
 if __name__ == '__main__':

@@ -80,7 +80,7 @@ from jax import vmap
 from jax.experimental import optimizers
 from jax.experimental import stax
 from jax.lax import stop_gradient
-import jax.numpy as np
+import jax.numpy as jnp
 from examples import datasets
 import numpy.random as npr
 
@@ -124,14 +124,14 @@ def loss(params, batch):
   inputs, targets = batch
   logits = predict(params, inputs)
   logits = stax.logsoftmax(logits)  # log normalize
-  return -np.mean(np.sum(logits * targets, axis=1))  # cross entropy loss
+  return -jnp.mean(jnp.sum(logits * targets, axis=1))  # cross entropy loss
 
 
 def accuracy(params, batch):
   inputs, targets = batch
-  target_class = np.argmax(targets, axis=1)
-  predicted_class = np.argmax(predict(params, inputs), axis=1)
-  return np.mean(predicted_class == target_class)
+  target_class = jnp.argmax(targets, axis=1)
+  predicted_class = jnp.argmax(predict(params, inputs), axis=1)
+  return jnp.mean(predicted_class == target_class)
 
 
 def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier,
@@ -143,9 +143,9 @@ def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier,
     grads = grad(loss)(params, single_example_batch)
 
     nonempty_grads, tree_def = tree_util.tree_flatten(grads)
-    total_grad_norm = np.linalg.norm(
-        [np.linalg.norm(neg.ravel()) for neg in nonempty_grads])
-    divisor = stop_gradient(np.amax((total_grad_norm / l2_norm_clip, 1.)))
+    total_grad_norm = jnp.linalg.norm(
+        [jnp.linalg.norm(neg.ravel()) for neg in nonempty_grads])
+    divisor = stop_gradient(jnp.amax((total_grad_norm / l2_norm_clip, 1.)))
     normalized_nonempty_grads = [g / divisor for g in nonempty_grads]
     return tree_util.tree_unflatten(tree_def, normalized_nonempty_grads)
 
@@ -154,7 +154,7 @@ def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier,
   noise_ = lambda n: n + std_dev * random.normal(rng, n.shape)
   normalize_ = lambda n: n / float(batch_size)
   tree_map = tree_util.tree_map
-  sum_ = lambda n: np.sum(n, 0)  # aggregate
+  sum_ = lambda n: jnp.sum(n, 0)  # aggregate
   aggregated_clipped_grads = tree_map(sum_, px_clipped_grad_fn(batch))
   noised_aggregated_clipped_grads = tree_map(noise_, aggregated_clipped_grads)
   normalized_noised_aggregated_clipped_grads = (
@@ -165,14 +165,14 @@ def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier,
 
 def shape_as_image(images, labels, dummy_dim=False):
   target_shape = (-1, 1, 28, 28, 1) if dummy_dim else (-1, 28, 28, 1)
-  return np.reshape(images, target_shape), labels
+  return jnp.reshape(images, target_shape), labels
 
 
 def compute_epsilon(steps, num_examples=60000, target_delta=1e-5):
   if num_examples * target_delta > 1.:
     warnings.warn('Your delta might be too high.')
   q = FLAGS.batch_size / float(num_examples)
-  orders = list(np.linspace(1.1, 10.9, 99)) + range(11, 64)
+  orders = list(jnp.linspace(1.1, 10.9, 99)) + range(11, 64)
   rdp_const = compute_rdp(q, FLAGS.noise_multiplier, steps, orders)
   eps, _, _ = get_privacy_spent(orders, rdp_const, target_delta=target_delta)
   return eps

@@ -19,11 +19,11 @@ import sys
 from absl.testing import absltest
 from absl.testing import parameterized
 
-import numpy as onp
+import numpy as np
 
 from jax import test_util as jtu
 from jax import random
-import jax.numpy as np
+import jax.numpy as jnp
 
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from examples import kernel_lsq
@@ -38,7 +38,7 @@ FLAGS = config.FLAGS
 def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape):
   jax_rng = random.PRNGKey(0)
   result_shape, params = init_fun(jax_rng, input_shape)
-  rng = onp.random.RandomState(0)
+  rng = np.random.RandomState(0)
   result = apply_fun(params, rng.randn(*input_shape).astype(dtype="float32"))
   test_case.assertEqual(result.shape, result_shape)
 
@@ -76,21 +76,21 @@ class ExamplesTest(jtu.JaxTestCase):
 
   def testKernelRegressionGram(self):
     n, d = 100, 20
-    rng = onp.random.RandomState(0)
+    rng = np.random.RandomState(0)
     truth = rng.randn(d)
     xs = rng.randn(n, d)
-    ys = np.dot(xs, truth)
-    kernel = lambda x, y: np.dot(x, y)
-    self.assertAllClose(kernel_lsq.gram(kernel, xs), np.dot(xs, xs.T),
+    ys = jnp.dot(xs, truth)
+    kernel = lambda x, y: jnp.dot(x, y)
+    self.assertAllClose(kernel_lsq.gram(kernel, xs), jnp.dot(xs, xs.T),
                         check_dtypes=False)
 
   def testKernelRegressionTrainAndPredict(self):
     n, d = 100, 20
-    rng = onp.random.RandomState(0)
+    rng = np.random.RandomState(0)
     truth = rng.randn(d)
     xs = rng.randn(n, d)
-    ys = np.dot(xs, truth)
-    kernel = lambda x, y: np.dot(x, y)
+    ys = jnp.dot(xs, truth)
+    kernel = lambda x, y: jnp.dot(x, y)
     predict = kernel_lsq.train(kernel, xs, ys)
     self.assertAllClose(predict(xs), ys, atol=1e-3, rtol=1e-3,
                         check_dtypes=False)

@@ -22,7 +22,7 @@ from jax import grad
 from jax import jit
 from jax import vmap
 from jax.config import config
-import jax.numpy as np
+import jax.numpy as jnp
 import jax.random as random
 import jax.scipy as scipy
 import matplotlib.pyplot as plt
@@ -34,7 +34,7 @@ def main(unused_argv):
 
   numpts = 7
   key = random.PRNGKey(0)
-  eye = np.eye(numpts)
+  eye = jnp.eye(numpts)
 
   def cov_map(cov_func, xs, xs2=None):
     """Compute a covariance matrix from a covariance function and data points.
@@ -51,19 +51,19 @@ def main(unused_argv):
     return vmap(lambda x: vmap(lambda y: cov_func(x, y))(xs))(xs2).T
 
   def softplus(x):
-    return np.logaddexp(x, 0.)
+    return jnp.logaddexp(x, 0.)
 
   # Note, writing out the vectorized form of the identity
   # ||x-y||^2 = <x-y,x-y> = ||x||^2 + ||y||^2 - 2<x,y>
   # for computing squared distances would be more efficient (but less succinct).
   def exp_quadratic(x1, x2):
-    return np.exp(-np.sum((x1 - x2)**2))
+    return jnp.exp(-jnp.sum((x1 - x2)**2))
 
   def gp(params, x, y, xtest=None, compute_marginal_likelihood=False):
     noise = softplus(params['noise'])
     amp = softplus(params['amplitude'])
     ls = softplus(params['lengthscale'])
-    ymean = np.mean(y)
+    ymean = jnp.mean(y)
     y = y - ymean
     x = x / ls
     train_cov = amp*cov_map(exp_quadratic, x) + eye * (noise + 1e-6)
@@ -71,20 +71,20 @@ def main(unused_argv):
     kinvy = scipy.linalg.solve_triangular(
         chol.T, scipy.linalg.solve_triangular(chol, y, lower=True))
     if compute_marginal_likelihood:
-      log2pi = np.log(2. * 3.1415)
-      ml = np.sum(
-          -0.5 * np.dot(y.T, kinvy) -
-          np.sum(np.log(np.diag(chol))) -
+      log2pi = jnp.log(2. * 3.1415)
+      ml = jnp.sum(
+          -0.5 * jnp.dot(y.T, kinvy) -
+          jnp.sum(jnp.log(jnp.diag(chol))) -
           (numpts / 2.) * log2pi)
-      ml -= np.sum(-0.5 * np.log(2 * 3.1415) - np.log(amp)**2)  # lognormal prior
+      ml -= jnp.sum(-0.5 * jnp.log(2 * 3.1415) - jnp.log(amp)**2)  # lognormal prior
       return -ml
 
     if xtest is not None:
       xtest = xtest / ls
     cross_cov = amp*cov_map(exp_quadratic, x, xtest)
-    mu = np.dot(cross_cov.T, kinvy) + ymean
+    mu = jnp.dot(cross_cov.T, kinvy) + ymean
     v = scipy.linalg.solve_triangular(chol, cross_cov, lower=True)
-    var = (amp * cov_map(exp_quadratic, xtest) - np.dot(v.T, v))
+    var = (amp * cov_map(exp_quadratic, xtest) - jnp.dot(v.T, v))
     return mu, var
 
   marginal_likelihood = partial(gp, compute_marginal_likelihood=True)
@@ -92,9 +92,9 @@ def main(unused_argv):
   grad_fun = jit(grad(marginal_likelihood))
 
   # Covariance hyperparameters to be learned
-  params = {"amplitude": np.zeros((1, 1)),
-            "noise": np.zeros((1, 1)) - 5.,
-            "lengthscale": np.zeros((1, 1))}
+  params = {"amplitude": jnp.zeros((1, 1)),
+            "noise": jnp.zeros((1, 1)) - 5.,
+            "lengthscale": jnp.zeros((1, 1))}
   momentums = dict([(k, p * 0.) for k, p in params.items()])
   scales = dict([(k, p * 0. + 1.) for k, p in params.items()])
 
@@ -104,14 +104,14 @@ def main(unused_argv):
     for k in (params):
       momentums[k] = 0.9 * momentums[k] + 0.1 * grads[k][0]
      scales[k] = 0.9 * scales[k] + 0.1 * grads[k][0]**2
-      params[k] -= lr * momentums[k]/np.sqrt(scales[k] + 1e-5)
+      params[k] -= lr * momentums[k]/jnp.sqrt(scales[k] + 1e-5)
    return params, momentums, scales
 
  # Create a really simple toy 1D function
-  y_fun = lambda x: np.sin(x) + 0.1 * random.normal(key, shape=(x.shape[0], 1))
+  y_fun = lambda x: jnp.sin(x) + 0.1 * random.normal(key, shape=(x.shape[0], 1))
  x = (random.uniform(key, shape=(numpts, 1)) * 4.) + 1
  y = y_fun(x)
-  xtest = np.linspace(0, 6., 200)[:, None]
+  xtest = jnp.linspace(0, 6., 200)[:, None]
  ytest = y_fun(xtest)
 
  for i in range(1000):
@@ -122,7 +122,7 @@ def main(unused_argv):
 
   print(params)
   mu, var = predict(params, x, y, xtest)
-  std = np.sqrt(np.diag(var))
+  std = jnp.sqrt(jnp.diag(var))
   plt.plot(x, y, "k.")
   plt.plot(xtest, mu)
   plt.fill_between(xtest.flatten(),

@@ -17,7 +17,7 @@ from functools import partial
 
 import numpy.random as npr
 
-import jax.numpy as np
+import jax.numpy as jnp
 from jax.config import config
 from jax.experimental import optimizers
 from jax import grad, jit, make_jaxpr, vmap
@@ -56,15 +56,15 @@ def train(kernel, xs, ys, regularization=0.01):
   n = xs.shape[0]
 
   def objective(v):
-    risk = .5 * np.sum((np.dot(gram_mat, v) - ys) ** 2.0)
-    reg = regularization * np.sum(v ** 2.0)
+    risk = .5 * jnp.sum((jnp.dot(gram_mat, v) - ys) ** 2.0)
+    reg = regularization * jnp.sum(v ** 2.0)
     return risk + reg
 
-  v = minimize(objective, np.zeros(n))
+  v = minimize(objective, jnp.zeros(n))
 
   def predict(x):
     prods = vmap(lambda x_: kernel(x, x_))(xs)
-    return np.sum(v * prods)
+    return jnp.sum(v * prods)
 
   return jit(vmap(predict))
 
@@ -75,19 +75,19 @@ if __name__ == "__main__":
 
   # linear kernel
 
-  linear_kernel = lambda x, y: np.dot(x, y)
+  linear_kernel = lambda x, y: jnp.dot(x, y)
   truth = npr.randn(d)
   xs = npr.randn(n, d)
-  ys = np.dot(xs, truth)
+  ys = jnp.dot(xs, truth)
 
   predict = train(linear_kernel, xs, ys)
 
-  print('MSE:', np.sum((predict(xs) - ys) ** 2.))
+  print('MSE:', jnp.sum((predict(xs) - ys) ** 2.))
 
   def gram_jaxpr(kernel):
     return make_jaxpr(partial(gram, kernel))(xs)
 
-  rbf_kernel = lambda x, y: np.exp(-np.sum((x - y) ** 2))
+  rbf_kernel = lambda x, y: jnp.exp(-jnp.sum((x - y) ** 2))
 
   print()
   print('jaxpr of gram(linear_kernel):')

@@ -25,7 +25,7 @@ import itertools
 
 import numpy.random as npr
 
-import jax.numpy as np
+import jax.numpy as jnp
 from jax.config import config
 from jax import jit, grad, random
 from jax.experimental import optimizers
@@ -37,13 +37,13 @@ from examples import datasets
 def loss(params, batch):
   inputs, targets = batch
   preds = predict(params, inputs)
-  return -np.mean(np.sum(preds * targets, axis=1))
+  return -jnp.mean(jnp.sum(preds * targets, axis=1))
 
 def accuracy(params, batch):
   inputs, targets = batch
-  target_class = np.argmax(targets, axis=1)
-  predicted_class = np.argmax(predict(params, inputs), axis=1)
-  return np.mean(predicted_class == target_class)
+  target_class = jnp.argmax(targets, axis=1)
+  predicted_class = jnp.argmax(predict(params, inputs), axis=1)
+  return jnp.mean(predicted_class == target_class)
 
 init_random_params, predict = stax.serial(
     Dense(1024), Relu,

@@ -25,7 +25,7 @@ import numpy.random as npr
 from jax.api import jit, grad
 from jax.config import config
 from jax.scipy.special import logsumexp
-import jax.numpy as np
+import jax.numpy as jnp
 from examples import datasets
 
 
@@ -36,23 +36,23 @@ def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
 def predict(params, inputs):
   activations = inputs
   for w, b in params[:-1]:
-    outputs = np.dot(activations, w) + b
-    activations = np.tanh(outputs)
+    outputs = jnp.dot(activations, w) + b
+    activations = jnp.tanh(outputs)
 
   final_w, final_b = params[-1]
-  logits = np.dot(activations, final_w) + final_b
+  logits = jnp.dot(activations, final_w) + final_b
   return logits - logsumexp(logits, axis=1, keepdims=True)
 
 def loss(params, batch):
   inputs, targets = batch
   preds = predict(params, inputs)
-  return -np.mean(np.sum(preds * targets, axis=1))
+  return -jnp.mean(jnp.sum(preds * targets, axis=1))
 
 def accuracy(params, batch):
   inputs, targets = batch
-  target_class = np.argmax(targets, axis=1)
-  predicted_class = np.argmax(predict(params, inputs), axis=1)
-  return np.mean(predicted_class == target_class)
+  target_class = jnp.argmax(targets, axis=1)
+  predicted_class = jnp.argmax(predict(params, inputs), axis=1)
+  return jnp.mean(predicted_class == target_class)
 
 
 if __name__ == "__main__":

@@ -24,7 +24,7 @@ import time
 
 import matplotlib.pyplot as plt
 
-import jax.numpy as np
+import jax.numpy as jnp
 from jax.config import config
 from jax import jit, grad, lax, random
 from jax.experimental import optimizers
@@ -35,15 +35,15 @@ from examples import datasets
 
 def gaussian_kl(mu, sigmasq):
   """KL divergence from a diagonal Gaussian to the standard Gaussian."""
-  return -0.5 * np.sum(1. + np.log(sigmasq) - mu**2. - sigmasq)
+  return -0.5 * jnp.sum(1. + jnp.log(sigmasq) - mu**2. - sigmasq)
 
 def gaussian_sample(rng, mu, sigmasq):
   """Sample a diagonal Gaussian."""
-  return mu + np.sqrt(sigmasq) * random.normal(rng, mu.shape)
+  return mu + jnp.sqrt(sigmasq) * random.normal(rng, mu.shape)
 
 def bernoulli_logpdf(logits, x):
   """Bernoulli log pdf of data x given logits."""
-  return -np.sum(np.logaddexp(0., np.where(x, -1., 1.) * logits))
+  return -jnp.sum(jnp.logaddexp(0., jnp.where(x, -1., 1.) * logits))
 
 def elbo(rng, params, images):
   """Monte Carlo estimate of the negative evidence lower bound."""
@@ -57,13 +57,13 @@ def image_sample(rng, params, nrow, ncol):
   _, dec_params = params
   code_rng, img_rng = random.split(rng)
   logits = decode(dec_params, random.normal(code_rng, (nrow * ncol, 10)))
-  sampled_images = random.bernoulli(img_rng, np.logaddexp(0., logits))
+  sampled_images = random.bernoulli(img_rng, jnp.logaddexp(0., logits))
   return image_grid(nrow, ncol, sampled_images, (28, 28))
 
 def image_grid(nrow, ncol, imagevecs, imshape):
   """Reshape a stack of image vectors into an image grid for plotting."""
   images = iter(imagevecs.reshape((-1,) + imshape))
-  return np.vstack([np.hstack([next(images).T for _ in range(ncol)][::-1])
+  return jnp.vstack([jnp.hstack([next(images).T for _ in range(ncol)][::-1])
                     for _ in range(nrow)]).T
 
 

@@ -20,7 +20,7 @@ optimization library.
 
 import numpy.random as npr
 
-import jax.numpy as np
+import jax.numpy as jnp
 from jax.config import config
 from jax import jit, grad, random
 from jax.experimental import optimizers
@@ -96,20 +96,20 @@ if __name__ == "__main__":
   def loss(params, batch):
     inputs, targets = batch
     logits = predict_fun(params, inputs)
-    return -np.sum(logits * targets)
+    return -jnp.sum(logits * targets)
 
   def accuracy(params, batch):
     inputs, targets = batch
-    target_class = np.argmax(targets, axis=-1)
-    predicted_class = np.argmax(predict_fun(params, inputs), axis=-1)
-    return np.mean(predicted_class == target_class)
+    target_class = jnp.argmax(targets, axis=-1)
+    predicted_class = jnp.argmax(predict_fun(params, inputs), axis=-1)
+    return jnp.mean(predicted_class == target_class)
 
   def synth_batches():
     rng = npr.RandomState(0)
     while True:
       images = rng.rand(*input_shape).astype('float32')
       labels = rng.randint(num_classes, size=(batch_size, 1))
-      onehot_labels = labels == np.arange(num_classes)
+      onehot_labels = labels == jnp.arange(num_classes)
       yield images, onehot_labels
 
   opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=0.9)

@@ -24,7 +24,7 @@ optimizers libraries.
 from functools import partial
 import time
 
-import numpy as onp
+import numpy as np
 import numpy.random as npr
 
 from jax import jit, grad, pmap
@@ -33,7 +33,7 @@ from jax.scipy.special import logsumexp
 from jax.lib import xla_bridge
 from jax.tree_util import tree_map
 from jax import lax
-import jax.numpy as np
+import jax.numpy as jnp
 from examples import datasets
 
 
@@ -44,24 +44,24 @@ def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
 def predict(params, inputs):
   activations = inputs
   for w, b in params[:-1]:
-    outputs = np.dot(activations, w) + b
-    activations = np.tanh(outputs)
+    outputs = jnp.dot(activations, w) + b
+    activations = jnp.tanh(outputs)
 
   final_w, final_b = params[-1]
-  logits = np.dot(activations, final_w) + final_b
+  logits = jnp.dot(activations, final_w) + final_b
   return logits - logsumexp(logits, axis=1, keepdims=True)
 
 def loss(params, batch):
   inputs, targets = batch
   preds = predict(params, inputs)
-  return -np.mean(np.sum(preds * targets, axis=1))
+  return -jnp.mean(jnp.sum(preds * targets, axis=1))
 
 @jit
 def accuracy(params, batch):
   inputs, targets = batch
-  target_class = np.argmax(targets, axis=1)
-  predicted_class = np.argmax(predict(params, inputs), axis=1)
-  return np.mean(predicted_class == target_class)
+  target_class = jnp.argmax(targets, axis=1)
+  predicted_class = jnp.argmax(predict(params, inputs), axis=1)
+  return jnp.mean(predicted_class == target_class)
 
 
 if __name__ == "__main__":
@@ -110,7 +110,7 @@ if __name__ == "__main__":
   # We replicate the parameters so that the constituent arrays have a leading
   # dimension of size equal to the number of devices we're pmapping over.
   init_params = init_random_params(param_scale, layer_sizes)
-  replicate_array = lambda x: onp.broadcast_to(x, (num_devices,) + x.shape)
+  replicate_array = lambda x: np.broadcast_to(x, (num_devices,) + x.shape)
   replicated_params = tree_map(replicate_array, init_params)
 
   for epoch in range(num_epochs):