mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #11629 from froystig:rm-control-example
PiperOrigin-RevId: 463462047
This commit is contained in:
commit
eb7040d6d6
@ -1,222 +0,0 @@
|
||||
# Copyright 2019 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Model-predictive non-linear control example.
|
||||
"""
|
||||
|
||||
import collections
|
||||
|
||||
from jax import lax, grad, jacfwd, jacobian, vmap
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
# Specifies a general finite-horizon, time-varying control problem. Given cost
|
||||
# function `c`, transition function `f`, and initial state `x0`, the goal is to
|
||||
# compute:
|
||||
#
|
||||
# argmin(lambda X, U: c(T, X[T]) + sum(c(t, X[t], U[t]) for t in range(T)))
|
||||
#
|
||||
# subject to the constraints that `X[0] == x0` and that:
|
||||
#
|
||||
# all(X[t + 1] == f(X[t], U[t]) for t in range(T)) .
|
||||
#
|
||||
# The special case in which `c` is quadratic and `f` is linear is the
|
||||
# linear-quadratic regulator (LQR) problem, and can be specified explicitly
|
||||
# further below.
|
||||
#
|
||||
ControlSpec = collections.namedtuple(
|
||||
'ControlSpec', 'cost dynamics horizon state_dim control_dim')
|
||||
|
||||
|
||||
# Specifies a finite-horizon, time-varying LQR problem. Notation:
|
||||
#
|
||||
# cost(t, x, u) = sum(
|
||||
# dot(x.T, Q[t], x) + dot(q[t], x) +
|
||||
# dot(u.T, R[t], u) + dot(r[t], u) +
|
||||
# dot(x.T, M[t], u)
|
||||
#
|
||||
# dynamics(t, x, u) = dot(A[t], x) + dot(B[t], u)
|
||||
#
|
||||
LqrSpec = collections.namedtuple('LqrSpec', 'Q q R r M A B')
|
||||
|
||||
|
||||
dot = jnp.dot
|
||||
mm = jnp.matmul
|
||||
|
||||
|
||||
def mv(mat, vec):
|
||||
assert mat.ndim == 2
|
||||
assert vec.ndim == 1
|
||||
return dot(mat, vec)
|
||||
|
||||
|
||||
LOOP_VIA_SCAN = False
|
||||
|
||||
|
||||
def fori_loop(lo, hi, loop, init):
|
||||
if LOOP_VIA_SCAN:
|
||||
return scan_fori_loop(lo, hi, loop, init)
|
||||
else:
|
||||
return lax.fori_loop(lo, hi, loop, init)
|
||||
|
||||
|
||||
def scan_fori_loop(lo, hi, loop, init):
|
||||
def scan_f(x, t):
|
||||
return loop(t, x), ()
|
||||
x, _ = lax.scan(scan_f, init, jnp.arange(lo, hi))
|
||||
return x
|
||||
|
||||
|
||||
def trajectory(dynamics, U, x0):
|
||||
'''Unrolls `X[t+1] = dynamics(t, X[t], U[t])`, where `X[0] = x0`.'''
|
||||
T, _ = U.shape
|
||||
d, = x0.shape
|
||||
|
||||
X = jnp.zeros((T + 1, d), dtype=x0.dtype)
|
||||
X = X.at[0].set(x0)
|
||||
|
||||
def loop(t, X):
|
||||
x = dynamics(t, X[t], U[t])
|
||||
X = X.at[t + 1].set(x)
|
||||
return X
|
||||
|
||||
return fori_loop(0, T, loop, X)
|
||||
|
||||
|
||||
def make_lqr_approx(p):
|
||||
T = p.horizon
|
||||
|
||||
def approx_timestep(t, x, u):
|
||||
M = jacfwd(grad(p.cost, argnums=2), argnums=1)(t, x, u).T
|
||||
Q = jacfwd(grad(p.cost, argnums=1), argnums=1)(t, x, u)
|
||||
R = jacfwd(grad(p.cost, argnums=2), argnums=2)(t, x, u)
|
||||
q, r = grad(p.cost, argnums=(1, 2))(t, x, u)
|
||||
A, B = jacobian(p.dynamics, argnums=(1, 2))(t, x, u)
|
||||
return Q, q, R, r, M, A, B
|
||||
|
||||
_approx = vmap(approx_timestep)
|
||||
|
||||
def approx(X, U):
|
||||
assert X.shape[0] == T + 1 and U.shape[0] == T
|
||||
U_pad = jnp.vstack((U, jnp.zeros((1,) + U.shape[1:])))
|
||||
Q, q, R, r, M, A, B = _approx(jnp.arange(T + 1), X, U_pad)
|
||||
return LqrSpec(Q, q, R[:T], r[:T], M[:T], A[:T], B[:T])
|
||||
|
||||
return approx
|
||||
|
||||
|
||||
def lqr_solve(spec):
|
||||
EPS = 1e-7
|
||||
T, control_dim, _ = spec.R.shape
|
||||
_, state_dim, _ = spec.Q.shape
|
||||
|
||||
K = jnp.zeros((T, control_dim, state_dim))
|
||||
k = jnp.zeros((T, control_dim))
|
||||
|
||||
def rev_loop(t_, state):
|
||||
t = T - t_ - 1
|
||||
spec, P, p, K, k = state
|
||||
|
||||
Q, q = spec.Q[t], spec.q[t]
|
||||
R, r = spec.R[t], spec.r[t]
|
||||
M = spec.M[t]
|
||||
A, B = spec.A[t], spec.B[t]
|
||||
|
||||
AtP = mm(A.T, P)
|
||||
BtP = mm(B.T, P)
|
||||
G = R + mm(BtP, B)
|
||||
H = mm(BtP, A) + M.T
|
||||
h = r + mv(B.T, p)
|
||||
K_ = -jnp.linalg.solve(G + EPS * jnp.eye(G.shape[0]), H)
|
||||
k_ = -jnp.linalg.solve(G + EPS * jnp.eye(G.shape[0]), h)
|
||||
P_ = Q + mm(AtP, A) + mm(K_.T, H)
|
||||
p_ = q + mv(A.T, p) + mv(K_.T, h)
|
||||
|
||||
K = K.at[t].set(K_)
|
||||
k = k.at[t].set(k_)
|
||||
return spec, P_, p_, K, k
|
||||
|
||||
_, P, p, K, k = fori_loop(
|
||||
0, T, rev_loop,
|
||||
(spec, spec.Q[T + 1], spec.q[T + 1], K, k))
|
||||
|
||||
return K, k
|
||||
|
||||
|
||||
def lqr_predict(spec, x0):
|
||||
T, control_dim, _ = spec.R.shape
|
||||
_, state_dim, _ = spec.Q.shape
|
||||
|
||||
K, k = lqr_solve(spec)
|
||||
|
||||
def fwd_loop(t, state):
|
||||
spec, X, U = state
|
||||
A, B = spec.A[t], spec.B[t]
|
||||
u = mv(K[t], X[t]) + k[t]
|
||||
x = mv(A, X[t]) + mv(B, u)
|
||||
X = X.at[t + 1].set(x)
|
||||
U = U.at[t].set(u)
|
||||
return spec, X, U
|
||||
|
||||
U = jnp.zeros((T, control_dim))
|
||||
X = jnp.zeros((T + 1, state_dim))
|
||||
X = X.at[0].set(x0)
|
||||
_, X, U = fori_loop(0, T, fwd_loop, (spec, X, U))
|
||||
return X, U
|
||||
|
||||
|
||||
def ilqr(iterations, p, x0, U):
|
||||
assert x0.ndim == 1 and x0.shape[0] == p.state_dim, x0.shape
|
||||
assert U.ndim > 0 and U.shape[0] == p.horizon, (U.shape, p.horizon)
|
||||
|
||||
lqr_approx = make_lqr_approx(p)
|
||||
|
||||
def loop(_, state):
|
||||
X, U = state
|
||||
p_lqr = lqr_approx(X, U)
|
||||
dX, dU = lqr_predict(p_lqr, jnp.zeros_like(x0))
|
||||
U = U + dU
|
||||
X = trajectory(p.dynamics, U, X[0] + dX[0])
|
||||
return X, U
|
||||
|
||||
X = trajectory(p.dynamics, U, x0)
|
||||
return fori_loop(0, iterations, loop, (X, U))
|
||||
|
||||
|
||||
def mpc_predict(solver, p, x0, U):
|
||||
assert x0.ndim == 1 and x0.shape[0] == p.state_dim
|
||||
T = p.horizon
|
||||
|
||||
def zero_padded_controls_window(U, t):
|
||||
U_pad = jnp.vstack((U, jnp.zeros(U.shape)))
|
||||
return lax.dynamic_slice_in_dim(U_pad, t, T, axis=0)
|
||||
|
||||
def loop(t, state):
|
||||
cost = lambda t_, x, u: p.cost(t + t_, x, u)
|
||||
dyns = lambda t_, x, u: p.dynamics(t + t_, x, u)
|
||||
|
||||
X, U = state
|
||||
p_ = ControlSpec(cost, dyns, T, p.state_dim, p.control_dim)
|
||||
xt = X[t]
|
||||
U_rem = zero_padded_controls_window(U, t)
|
||||
_, U_ = solver(p_, xt, U_rem)
|
||||
ut = U_[0]
|
||||
x = p.dynamics(t, xt, ut)
|
||||
X = X.at[t + 1].set(x)
|
||||
U = U.at[t].set(ut)
|
||||
return X, U
|
||||
|
||||
X = jnp.zeros((T + 1, p.state_dim))
|
||||
X = X.at[0].set(x0)
|
||||
return fori_loop(0, T, loop, (X, U))
|
@ -1,245 +0,0 @@
|
||||
# Copyright 2019 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
import zlib
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from examples import control
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
def one_step_lqr(dim, T):
|
||||
Q = jnp.stack(T * (jnp.eye(dim),))
|
||||
q = jnp.zeros((T, dim))
|
||||
R = jnp.zeros((T, dim, dim))
|
||||
r = jnp.zeros((T, dim))
|
||||
M = jnp.zeros((T, dim, dim))
|
||||
A = jnp.stack(T * (jnp.eye(dim),))
|
||||
B = jnp.stack(T * (jnp.eye(dim),))
|
||||
return control.LqrSpec(Q, q, R, r, M, A, B)
|
||||
|
||||
|
||||
def control_from_lqr(lqr):
|
||||
T, dim, _ = lqr.Q.shape
|
||||
dot = jnp.dot
|
||||
|
||||
def cost(t, x, u):
|
||||
return (
|
||||
dot(dot(lqr.Q[t], x), x) + dot(lqr.q[t], x) +
|
||||
dot(dot(lqr.R[t], u), u) + dot(lqr.r[t], u) +
|
||||
dot(dot(lqr.M[t], u), x))
|
||||
|
||||
def dynamics(t, x, u):
|
||||
return dot(lqr.A[t], x) + dot(lqr.B[t], u)
|
||||
|
||||
return control.ControlSpec(cost, dynamics, T, dim, dim)
|
||||
|
||||
|
||||
def one_step_control(dim, T):
|
||||
|
||||
def cost(t, x, u):
|
||||
return jnp.dot(x, x)
|
||||
|
||||
def dynamics(t, x, u):
|
||||
return x + u
|
||||
|
||||
return control.ControlSpec(cost, dynamics, T, dim, dim)
|
||||
|
||||
|
||||
class ControlExampleTest(parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.rng = np.random.default_rng(zlib.adler32(self.__class__.__name__.encode()))
|
||||
|
||||
def testTrajectoryCyclicIntegerCounter(self):
|
||||
num_states = 3
|
||||
|
||||
def dynamics(t, x, u):
|
||||
return (x + u) % num_states
|
||||
|
||||
T = 10
|
||||
|
||||
U = jnp.ones((T, 1))
|
||||
X = control.trajectory(dynamics, U, jnp.zeros(1))
|
||||
expected = jnp.arange(T + 1) % num_states
|
||||
expected = jnp.reshape(expected, (T + 1, 1))
|
||||
np.testing.assert_allclose(X, expected)
|
||||
|
||||
U = 2 * jnp.ones((T, 1))
|
||||
X = control.trajectory(dynamics, U, jnp.zeros(1))
|
||||
expected = jnp.cumsum(2 * jnp.ones(T)) % num_states
|
||||
expected = jnp.concatenate((jnp.zeros(1), expected))
|
||||
expected = jnp.reshape(expected, (T + 1, 1))
|
||||
np.testing.assert_allclose(X, expected)
|
||||
|
||||
def testTrajectoryTimeVarying(self):
|
||||
T = 6
|
||||
|
||||
def clip(x, lo, hi):
|
||||
return jnp.minimum(hi, jnp.maximum(lo, x))
|
||||
|
||||
def dynamics(t, x, u):
|
||||
# computes `(x + u) if t > T else 0`
|
||||
return (x + u) * clip(t - T, 0, 1)
|
||||
|
||||
U = jnp.ones((2 * T, 1))
|
||||
X = control.trajectory(dynamics, U, jnp.zeros(1))
|
||||
expected = jnp.concatenate((jnp.zeros(T + 1), jnp.arange(T, dtype=float)))
|
||||
expected = jnp.reshape(expected, (2 * T + 1, 1))
|
||||
np.testing.assert_allclose(X, expected)
|
||||
|
||||
|
||||
def testTrajectoryCyclicIndicator(self):
|
||||
num_states = 3
|
||||
|
||||
def position(x):
|
||||
'''finds the index of a standard basis vector, e.g. [0, 1, 0] -> 1'''
|
||||
x = jnp.cumsum(x)
|
||||
x = 1 - x
|
||||
return jnp.sum(x, dtype=jnp.int32)
|
||||
|
||||
def dynamics(t, x, u):
|
||||
'''moves the next standard basis vector'''
|
||||
idx = (position(x) + u[0]) % num_states
|
||||
return lax.dynamic_slice_in_dim(jnp.eye(num_states, dtype=jnp.int32), idx, 1)[0]
|
||||
|
||||
T = 8
|
||||
|
||||
U = jnp.ones((T, 1), dtype=jnp.int32)
|
||||
X = control.trajectory(dynamics, U, jnp.eye(num_states, dtype=jnp.int32)[0])
|
||||
expected = jnp.vstack((jnp.eye(num_states),) * 3)
|
||||
np.testing.assert_allclose(X, expected)
|
||||
|
||||
|
||||
def testLqrSolve(self):
|
||||
dim, T = 2, 10
|
||||
p = one_step_lqr(dim, T)
|
||||
K, k = control.lqr_solve(p)
|
||||
K_ = -jnp.stack(T * (jnp.eye(dim),))
|
||||
np.testing.assert_allclose(K, K_, atol=1e-6, rtol=1e-6)
|
||||
np.testing.assert_allclose(k, jnp.zeros((T, dim)))
|
||||
|
||||
|
||||
def testLqrPredict(self):
|
||||
dim, T = 2, 10
|
||||
p = one_step_lqr(dim, T)
|
||||
x0 = jnp.array(self.rng.normal(size=dim))
|
||||
X, U = control.lqr_predict(p, x0)
|
||||
np.testing.assert_allclose(X[0], x0)
|
||||
np.testing.assert_allclose(U[0], -x0,
|
||||
atol=1e-6, rtol=1e-6)
|
||||
np.testing.assert_allclose(X[1:], jnp.zeros((T, 2)),
|
||||
atol=1e-6, rtol=1e-6)
|
||||
np.testing.assert_allclose(U[1:], jnp.zeros((T - 1, 2)),
|
||||
atol=1e-6, rtol=1e-6)
|
||||
|
||||
|
||||
def testIlqrWithLqrProblem(self):
|
||||
dim, T, num_iters = 2, 10, 3
|
||||
lqr = one_step_lqr(dim, T)
|
||||
p = control_from_lqr(lqr)
|
||||
x0 = jnp.array(self.rng.normal(size=dim))
|
||||
X, U = control.ilqr(num_iters, p, x0, jnp.zeros((T, dim)))
|
||||
np.testing.assert_allclose(X[0], x0)
|
||||
np.testing.assert_allclose(U[0], -x0)
|
||||
np.testing.assert_allclose(X[1:], jnp.zeros((T, 2)), atol=1E-15)
|
||||
np.testing.assert_allclose(U[1:], jnp.zeros((T - 1, 2)), atol=1E-15)
|
||||
|
||||
|
||||
def testIlqrWithLqrProblemSpecifiedGenerally(self):
|
||||
dim, T, num_iters = 2, 10, 3
|
||||
p = one_step_control(dim, T)
|
||||
x0 = jnp.array(self.rng.normal(size=dim))
|
||||
X, U = control.ilqr(num_iters, p, x0, jnp.zeros((T, dim)))
|
||||
np.testing.assert_allclose(X[0], x0)
|
||||
np.testing.assert_allclose(U[0], -x0)
|
||||
np.testing.assert_allclose(X[1:], jnp.zeros((T, 2)), atol=1E-15)
|
||||
np.testing.assert_allclose(U[1:], jnp.zeros((T - 1, 2)), atol=1E-15)
|
||||
|
||||
|
||||
def testIlqrWithNonlinearProblem(self):
|
||||
def cost(t, x, u):
|
||||
return (x[0] ** 2. + 1e-3 * u[0] ** 2.) / (t + 1).astype(x.dtype)
|
||||
|
||||
def dynamics(t, x, u):
|
||||
return (x ** 2. - u ** 2.) / (t + 1).astype(x.dtype)
|
||||
|
||||
T, num_iters, d = 10, 7, 1
|
||||
p = control.ControlSpec(cost, dynamics, T, d, d)
|
||||
|
||||
x0 = jnp.array([0.2])
|
||||
X, U = control.ilqr(num_iters, p, x0, 1e-5 * jnp.ones((T, d)))
|
||||
assert_close = partial(np.testing.assert_allclose, atol=1e-2)
|
||||
assert_close(X[0], x0)
|
||||
assert_close(U[0] ** 2., x0 ** 2.)
|
||||
assert_close(X[1:], jnp.zeros((T, d)))
|
||||
assert_close(U[1:], jnp.zeros((T - 1, d)))
|
||||
|
||||
|
||||
def testMpcWithLqrProblem(self):
|
||||
dim, T, num_iters = 2, 10, 3
|
||||
lqr = one_step_lqr(dim, T)
|
||||
p = control_from_lqr(lqr)
|
||||
x0 = jnp.array(self.rng.normal(size=dim))
|
||||
solver = partial(control.ilqr, num_iters)
|
||||
X, U = control.mpc_predict(solver, p, x0, jnp.zeros((T, dim)))
|
||||
np.testing.assert_allclose(X[0], x0)
|
||||
np.testing.assert_allclose(U[0], -x0)
|
||||
np.testing.assert_allclose(X[1:], jnp.zeros((T, 2)))
|
||||
np.testing.assert_allclose(U[1:], jnp.zeros((T - 1, 2)))
|
||||
|
||||
|
||||
def testMpcWithLqrProblemSpecifiedGenerally(self):
|
||||
dim, T, num_iters = 2, 10, 3
|
||||
p = one_step_control(dim, T)
|
||||
x0 = jnp.array(self.rng.normal(size=dim))
|
||||
solver = partial(control.ilqr, num_iters)
|
||||
X, U = control.mpc_predict(solver, p, x0, jnp.zeros((T, dim)))
|
||||
np.testing.assert_allclose(X[0], x0)
|
||||
np.testing.assert_allclose(U[0], -x0)
|
||||
np.testing.assert_allclose(X[1:], jnp.zeros((T, 2)))
|
||||
np.testing.assert_allclose(U[1:], jnp.zeros((T - 1, 2)))
|
||||
|
||||
|
||||
def testMpcWithNonlinearProblem(self):
|
||||
def cost(t, x, u):
|
||||
return (x[0] ** 2. + 1e-3 * u[0] ** 2.) / (t + 1).astype(x.dtype)
|
||||
|
||||
def dynamics(t, x, u):
|
||||
return (x ** 2. - u ** 2.) / (t + 1).astype(x.dtype)
|
||||
|
||||
T, num_iters, d = 10, 7, 1
|
||||
p = control.ControlSpec(cost, dynamics, T, d, d)
|
||||
|
||||
x0 = jnp.array([0.2])
|
||||
solver = partial(control.ilqr, num_iters)
|
||||
X, U = control.mpc_predict(solver, p, x0, 1e-5 * jnp.ones((T, d)))
|
||||
assert_close = partial(np.testing.assert_allclose, atol=1e-2)
|
||||
assert_close(X[0], x0)
|
||||
assert_close(U[0] ** 2., x0 ** 2.)
|
||||
assert_close(X[1:], jnp.zeros((T, d)))
|
||||
assert_close(U[1:], jnp.zeros((T - 1, d)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
Loading…
x
Reference in New Issue
Block a user