[x64] make examples/control_test compatible with strict dtype promotion

This commit is contained in:
Jake VanderPlas 2022-06-16 16:20:54 -07:00
parent f88a1efbb4
commit 1a995a0c61
2 changed files with 12 additions and 12 deletions

View File

@ -83,7 +83,7 @@ def trajectory(dynamics, U, x0):
T, _ = U.shape
d, = x0.shape
X = jnp.zeros((T + 1, d))
X = jnp.zeros((T + 1, d), dtype=x0.dtype)
X = X.at[0].set(x0)
def loop(t, X):

View File

@ -104,7 +104,7 @@ class ControlExampleTest(parameterized.TestCase):
U = jnp.ones((2 * T, 1))
X = control.trajectory(dynamics, U, jnp.zeros(1))
expected = jnp.concatenate((jnp.zeros(T + 1), jnp.arange(T)))
expected = jnp.concatenate((jnp.zeros(T + 1), jnp.arange(T, dtype=float)))
expected = jnp.reshape(expected, (2 * T + 1, 1))
np.testing.assert_allclose(X, expected)
@ -121,7 +121,7 @@ class ControlExampleTest(parameterized.TestCase):
def dynamics(t, x, u):
'''moves the next standard basis vector'''
idx = (position(x) + u[0]) % num_states
return lax.dynamic_slice_in_dim(jnp.eye(num_states), idx, 1)[0]
return lax.dynamic_slice_in_dim(jnp.eye(num_states, dtype=jnp.int32), idx, 1)[0]
T = 8
@ -143,7 +143,7 @@ class ControlExampleTest(parameterized.TestCase):
def testLqrPredict(self):
dim, T = 2, 10
p = one_step_lqr(dim, T)
x0 = self.rng.normal(size=dim)
x0 = jnp.array(self.rng.normal(size=dim))
X, U = control.lqr_predict(p, x0)
np.testing.assert_allclose(X[0], x0)
np.testing.assert_allclose(U[0], -x0,
@ -158,7 +158,7 @@ class ControlExampleTest(parameterized.TestCase):
dim, T, num_iters = 2, 10, 3
lqr = one_step_lqr(dim, T)
p = control_from_lqr(lqr)
x0 = self.rng.normal(size=dim)
x0 = jnp.array(self.rng.normal(size=dim))
X, U = control.ilqr(num_iters, p, x0, jnp.zeros((T, dim)))
np.testing.assert_allclose(X[0], x0)
np.testing.assert_allclose(U[0], -x0)
@ -169,7 +169,7 @@ class ControlExampleTest(parameterized.TestCase):
def testIlqrWithLqrProblemSpecifiedGenerally(self):
dim, T, num_iters = 2, 10, 3
p = one_step_control(dim, T)
x0 = self.rng.normal(size=dim)
x0 = jnp.array(self.rng.normal(size=dim))
X, U = control.ilqr(num_iters, p, x0, jnp.zeros((T, dim)))
np.testing.assert_allclose(X[0], x0)
np.testing.assert_allclose(U[0], -x0)
@ -179,10 +179,10 @@ class ControlExampleTest(parameterized.TestCase):
def testIlqrWithNonlinearProblem(self):
def cost(t, x, u):
return (x[0] ** 2. + 1e-3 * u[0] ** 2.) / (t + 1.)
return (x[0] ** 2. + 1e-3 * u[0] ** 2.) / (t + 1).astype(x.dtype)
def dynamics(t, x, u):
return (x ** 2. - u ** 2.) / (t + 1.)
return (x ** 2. - u ** 2.) / (t + 1).astype(x.dtype)
T, num_iters, d = 10, 7, 1
p = control.ControlSpec(cost, dynamics, T, d, d)
@ -200,7 +200,7 @@ class ControlExampleTest(parameterized.TestCase):
dim, T, num_iters = 2, 10, 3
lqr = one_step_lqr(dim, T)
p = control_from_lqr(lqr)
x0 = self.rng.normal(size=dim)
x0 = jnp.array(self.rng.normal(size=dim))
solver = partial(control.ilqr, num_iters)
X, U = control.mpc_predict(solver, p, x0, jnp.zeros((T, dim)))
np.testing.assert_allclose(X[0], x0)
@ -212,7 +212,7 @@ class ControlExampleTest(parameterized.TestCase):
def testMpcWithLqrProblemSpecifiedGenerally(self):
dim, T, num_iters = 2, 10, 3
p = one_step_control(dim, T)
x0 = self.rng.normal(size=dim)
x0 = jnp.array(self.rng.normal(size=dim))
solver = partial(control.ilqr, num_iters)
X, U = control.mpc_predict(solver, p, x0, jnp.zeros((T, dim)))
np.testing.assert_allclose(X[0], x0)
@ -223,10 +223,10 @@ class ControlExampleTest(parameterized.TestCase):
def testMpcWithNonlinearProblem(self):
def cost(t, x, u):
return (x[0] ** 2. + 1e-3 * u[0] ** 2.) / (t + 1.)
return (x[0] ** 2. + 1e-3 * u[0] ** 2.) / (t + 1).astype(x.dtype)
def dynamics(t, x, u):
return (x ** 2. - u ** 2.) / (t + 1.)
return (x ** 2. - u ** 2.) / (t + 1).astype(x.dtype)
T, num_iters, d = 10, 7, 1
p = control.ControlSpec(cost, dynamics, T, d, d)