mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[x64] make examples/control_test compatible with strict dtype promotion
This commit is contained in:
parent
f88a1efbb4
commit
1a995a0c61
@ -83,7 +83,7 @@ def trajectory(dynamics, U, x0):
|
||||
T, _ = U.shape
|
||||
d, = x0.shape
|
||||
|
||||
X = jnp.zeros((T + 1, d))
|
||||
X = jnp.zeros((T + 1, d), dtype=x0.dtype)
|
||||
X = X.at[0].set(x0)
|
||||
|
||||
def loop(t, X):
|
||||
|
@ -104,7 +104,7 @@ class ControlExampleTest(parameterized.TestCase):
|
||||
|
||||
U = jnp.ones((2 * T, 1))
|
||||
X = control.trajectory(dynamics, U, jnp.zeros(1))
|
||||
expected = jnp.concatenate((jnp.zeros(T + 1), jnp.arange(T)))
|
||||
expected = jnp.concatenate((jnp.zeros(T + 1), jnp.arange(T, dtype=float)))
|
||||
expected = jnp.reshape(expected, (2 * T + 1, 1))
|
||||
np.testing.assert_allclose(X, expected)
|
||||
|
||||
@ -121,7 +121,7 @@ class ControlExampleTest(parameterized.TestCase):
|
||||
def dynamics(t, x, u):
|
||||
'''moves the next standard basis vector'''
|
||||
idx = (position(x) + u[0]) % num_states
|
||||
return lax.dynamic_slice_in_dim(jnp.eye(num_states), idx, 1)[0]
|
||||
return lax.dynamic_slice_in_dim(jnp.eye(num_states, dtype=jnp.int32), idx, 1)[0]
|
||||
|
||||
T = 8
|
||||
|
||||
@ -143,7 +143,7 @@ class ControlExampleTest(parameterized.TestCase):
|
||||
def testLqrPredict(self):
|
||||
dim, T = 2, 10
|
||||
p = one_step_lqr(dim, T)
|
||||
x0 = self.rng.normal(size=dim)
|
||||
x0 = jnp.array(self.rng.normal(size=dim))
|
||||
X, U = control.lqr_predict(p, x0)
|
||||
np.testing.assert_allclose(X[0], x0)
|
||||
np.testing.assert_allclose(U[0], -x0,
|
||||
@ -158,7 +158,7 @@ class ControlExampleTest(parameterized.TestCase):
|
||||
dim, T, num_iters = 2, 10, 3
|
||||
lqr = one_step_lqr(dim, T)
|
||||
p = control_from_lqr(lqr)
|
||||
x0 = self.rng.normal(size=dim)
|
||||
x0 = jnp.array(self.rng.normal(size=dim))
|
||||
X, U = control.ilqr(num_iters, p, x0, jnp.zeros((T, dim)))
|
||||
np.testing.assert_allclose(X[0], x0)
|
||||
np.testing.assert_allclose(U[0], -x0)
|
||||
@ -169,7 +169,7 @@ class ControlExampleTest(parameterized.TestCase):
|
||||
def testIlqrWithLqrProblemSpecifiedGenerally(self):
|
||||
dim, T, num_iters = 2, 10, 3
|
||||
p = one_step_control(dim, T)
|
||||
x0 = self.rng.normal(size=dim)
|
||||
x0 = jnp.array(self.rng.normal(size=dim))
|
||||
X, U = control.ilqr(num_iters, p, x0, jnp.zeros((T, dim)))
|
||||
np.testing.assert_allclose(X[0], x0)
|
||||
np.testing.assert_allclose(U[0], -x0)
|
||||
@ -179,10 +179,10 @@ class ControlExampleTest(parameterized.TestCase):
|
||||
|
||||
def testIlqrWithNonlinearProblem(self):
|
||||
def cost(t, x, u):
|
||||
return (x[0] ** 2. + 1e-3 * u[0] ** 2.) / (t + 1.)
|
||||
return (x[0] ** 2. + 1e-3 * u[0] ** 2.) / (t + 1).astype(x.dtype)
|
||||
|
||||
def dynamics(t, x, u):
|
||||
return (x ** 2. - u ** 2.) / (t + 1.)
|
||||
return (x ** 2. - u ** 2.) / (t + 1).astype(x.dtype)
|
||||
|
||||
T, num_iters, d = 10, 7, 1
|
||||
p = control.ControlSpec(cost, dynamics, T, d, d)
|
||||
@ -200,7 +200,7 @@ class ControlExampleTest(parameterized.TestCase):
|
||||
dim, T, num_iters = 2, 10, 3
|
||||
lqr = one_step_lqr(dim, T)
|
||||
p = control_from_lqr(lqr)
|
||||
x0 = self.rng.normal(size=dim)
|
||||
x0 = jnp.array(self.rng.normal(size=dim))
|
||||
solver = partial(control.ilqr, num_iters)
|
||||
X, U = control.mpc_predict(solver, p, x0, jnp.zeros((T, dim)))
|
||||
np.testing.assert_allclose(X[0], x0)
|
||||
@ -212,7 +212,7 @@ class ControlExampleTest(parameterized.TestCase):
|
||||
def testMpcWithLqrProblemSpecifiedGenerally(self):
|
||||
dim, T, num_iters = 2, 10, 3
|
||||
p = one_step_control(dim, T)
|
||||
x0 = self.rng.normal(size=dim)
|
||||
x0 = jnp.array(self.rng.normal(size=dim))
|
||||
solver = partial(control.ilqr, num_iters)
|
||||
X, U = control.mpc_predict(solver, p, x0, jnp.zeros((T, dim)))
|
||||
np.testing.assert_allclose(X[0], x0)
|
||||
@ -223,10 +223,10 @@ class ControlExampleTest(parameterized.TestCase):
|
||||
|
||||
def testMpcWithNonlinearProblem(self):
|
||||
def cost(t, x, u):
|
||||
return (x[0] ** 2. + 1e-3 * u[0] ** 2.) / (t + 1.)
|
||||
return (x[0] ** 2. + 1e-3 * u[0] ** 2.) / (t + 1).astype(x.dtype)
|
||||
|
||||
def dynamics(t, x, u):
|
||||
return (x ** 2. - u ** 2.) / (t + 1.)
|
||||
return (x ** 2. - u ** 2.) / (t + 1).astype(x.dtype)
|
||||
|
||||
T, num_iters, d = 10, 7, 1
|
||||
p = control.ControlSpec(cost, dynamics, T, d, d)
|
||||
|
Loading…
x
Reference in New Issue
Block a user