Merge pull request #11629 from froystig:rm-control-example

PiperOrigin-RevId: 463462047
This commit is contained in:
jax authors 2022-07-26 17:12:51 -07:00
commit eb7040d6d6
2 changed files with 0 additions and 467 deletions

View File

@ -1,222 +0,0 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Model-predictive non-linear control example.
"""
import collections
from jax import lax, grad, jacfwd, jacobian, vmap
import jax.numpy as jnp
# Specifies a general finite-horizon, time-varying control problem. Given cost
# function `c`, transition function `f`, and initial state `x0`, the goal is to
# compute:
#
# argmin(lambda X, U: c(T, X[T]) + sum(c(t, X[t], U[t]) for t in range(T)))
#
# subject to the constraints that `X[0] == x0` and that:
#
# all(X[t + 1] == f(X[t], U[t]) for t in range(T)) .
#
# The special case in which `c` is quadratic and `f` is linear is the
# linear-quadratic regulator (LQR) problem, and can be specified explicitly
# further below.
#
ControlSpec = collections.namedtuple(
'ControlSpec', 'cost dynamics horizon state_dim control_dim')
# Specifies a finite-horizon, time-varying LQR problem. Notation:
#
# cost(t, x, u) = sum(
# dot(x.T, Q[t], x) + dot(q[t], x) +
# dot(u.T, R[t], u) + dot(r[t], u) +
# dot(x.T, M[t], u)
#
# dynamics(t, x, u) = dot(A[t], x) + dot(B[t], u)
#
LqrSpec = collections.namedtuple('LqrSpec', 'Q q R r M A B')
dot = jnp.dot
mm = jnp.matmul
def mv(mat, vec):
assert mat.ndim == 2
assert vec.ndim == 1
return dot(mat, vec)
LOOP_VIA_SCAN = False
def fori_loop(lo, hi, loop, init):
if LOOP_VIA_SCAN:
return scan_fori_loop(lo, hi, loop, init)
else:
return lax.fori_loop(lo, hi, loop, init)
def scan_fori_loop(lo, hi, loop, init):
def scan_f(x, t):
return loop(t, x), ()
x, _ = lax.scan(scan_f, init, jnp.arange(lo, hi))
return x
def trajectory(dynamics, U, x0):
'''Unrolls `X[t+1] = dynamics(t, X[t], U[t])`, where `X[0] = x0`.'''
T, _ = U.shape
d, = x0.shape
X = jnp.zeros((T + 1, d), dtype=x0.dtype)
X = X.at[0].set(x0)
def loop(t, X):
x = dynamics(t, X[t], U[t])
X = X.at[t + 1].set(x)
return X
return fori_loop(0, T, loop, X)
def make_lqr_approx(p):
T = p.horizon
def approx_timestep(t, x, u):
M = jacfwd(grad(p.cost, argnums=2), argnums=1)(t, x, u).T
Q = jacfwd(grad(p.cost, argnums=1), argnums=1)(t, x, u)
R = jacfwd(grad(p.cost, argnums=2), argnums=2)(t, x, u)
q, r = grad(p.cost, argnums=(1, 2))(t, x, u)
A, B = jacobian(p.dynamics, argnums=(1, 2))(t, x, u)
return Q, q, R, r, M, A, B
_approx = vmap(approx_timestep)
def approx(X, U):
assert X.shape[0] == T + 1 and U.shape[0] == T
U_pad = jnp.vstack((U, jnp.zeros((1,) + U.shape[1:])))
Q, q, R, r, M, A, B = _approx(jnp.arange(T + 1), X, U_pad)
return LqrSpec(Q, q, R[:T], r[:T], M[:T], A[:T], B[:T])
return approx
def lqr_solve(spec):
EPS = 1e-7
T, control_dim, _ = spec.R.shape
_, state_dim, _ = spec.Q.shape
K = jnp.zeros((T, control_dim, state_dim))
k = jnp.zeros((T, control_dim))
def rev_loop(t_, state):
t = T - t_ - 1
spec, P, p, K, k = state
Q, q = spec.Q[t], spec.q[t]
R, r = spec.R[t], spec.r[t]
M = spec.M[t]
A, B = spec.A[t], spec.B[t]
AtP = mm(A.T, P)
BtP = mm(B.T, P)
G = R + mm(BtP, B)
H = mm(BtP, A) + M.T
h = r + mv(B.T, p)
K_ = -jnp.linalg.solve(G + EPS * jnp.eye(G.shape[0]), H)
k_ = -jnp.linalg.solve(G + EPS * jnp.eye(G.shape[0]), h)
P_ = Q + mm(AtP, A) + mm(K_.T, H)
p_ = q + mv(A.T, p) + mv(K_.T, h)
K = K.at[t].set(K_)
k = k.at[t].set(k_)
return spec, P_, p_, K, k
_, P, p, K, k = fori_loop(
0, T, rev_loop,
(spec, spec.Q[T + 1], spec.q[T + 1], K, k))
return K, k
def lqr_predict(spec, x0):
T, control_dim, _ = spec.R.shape
_, state_dim, _ = spec.Q.shape
K, k = lqr_solve(spec)
def fwd_loop(t, state):
spec, X, U = state
A, B = spec.A[t], spec.B[t]
u = mv(K[t], X[t]) + k[t]
x = mv(A, X[t]) + mv(B, u)
X = X.at[t + 1].set(x)
U = U.at[t].set(u)
return spec, X, U
U = jnp.zeros((T, control_dim))
X = jnp.zeros((T + 1, state_dim))
X = X.at[0].set(x0)
_, X, U = fori_loop(0, T, fwd_loop, (spec, X, U))
return X, U
def ilqr(iterations, p, x0, U):
assert x0.ndim == 1 and x0.shape[0] == p.state_dim, x0.shape
assert U.ndim > 0 and U.shape[0] == p.horizon, (U.shape, p.horizon)
lqr_approx = make_lqr_approx(p)
def loop(_, state):
X, U = state
p_lqr = lqr_approx(X, U)
dX, dU = lqr_predict(p_lqr, jnp.zeros_like(x0))
U = U + dU
X = trajectory(p.dynamics, U, X[0] + dX[0])
return X, U
X = trajectory(p.dynamics, U, x0)
return fori_loop(0, iterations, loop, (X, U))
def mpc_predict(solver, p, x0, U):
assert x0.ndim == 1 and x0.shape[0] == p.state_dim
T = p.horizon
def zero_padded_controls_window(U, t):
U_pad = jnp.vstack((U, jnp.zeros(U.shape)))
return lax.dynamic_slice_in_dim(U_pad, t, T, axis=0)
def loop(t, state):
cost = lambda t_, x, u: p.cost(t + t_, x, u)
dyns = lambda t_, x, u: p.dynamics(t + t_, x, u)
X, U = state
p_ = ControlSpec(cost, dyns, T, p.state_dim, p.control_dim)
xt = X[t]
U_rem = zero_padded_controls_window(U, t)
_, U_ = solver(p_, xt, U_rem)
ut = U_[0]
x = p.dynamics(t, xt, ut)
X = X.at[t + 1].set(x)
U = U.at[t].set(ut)
return X, U
X = jnp.zeros((T + 1, p.state_dim))
X = X.at[0].set(x0)
return fori_loop(0, T, loop, (X, U))

View File

@ -1,245 +0,0 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import zlib
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from jax import lax
import jax.numpy as jnp
from examples import control
from jax.config import config
config.parse_flags_with_absl()
def one_step_lqr(dim, T):
Q = jnp.stack(T * (jnp.eye(dim),))
q = jnp.zeros((T, dim))
R = jnp.zeros((T, dim, dim))
r = jnp.zeros((T, dim))
M = jnp.zeros((T, dim, dim))
A = jnp.stack(T * (jnp.eye(dim),))
B = jnp.stack(T * (jnp.eye(dim),))
return control.LqrSpec(Q, q, R, r, M, A, B)
def control_from_lqr(lqr):
T, dim, _ = lqr.Q.shape
dot = jnp.dot
def cost(t, x, u):
return (
dot(dot(lqr.Q[t], x), x) + dot(lqr.q[t], x) +
dot(dot(lqr.R[t], u), u) + dot(lqr.r[t], u) +
dot(dot(lqr.M[t], u), x))
def dynamics(t, x, u):
return dot(lqr.A[t], x) + dot(lqr.B[t], u)
return control.ControlSpec(cost, dynamics, T, dim, dim)
def one_step_control(dim, T):
def cost(t, x, u):
return jnp.dot(x, x)
def dynamics(t, x, u):
return x + u
return control.ControlSpec(cost, dynamics, T, dim, dim)
class ControlExampleTest(parameterized.TestCase):
def setUp(self):
self.rng = np.random.default_rng(zlib.adler32(self.__class__.__name__.encode()))
def testTrajectoryCyclicIntegerCounter(self):
num_states = 3
def dynamics(t, x, u):
return (x + u) % num_states
T = 10
U = jnp.ones((T, 1))
X = control.trajectory(dynamics, U, jnp.zeros(1))
expected = jnp.arange(T + 1) % num_states
expected = jnp.reshape(expected, (T + 1, 1))
np.testing.assert_allclose(X, expected)
U = 2 * jnp.ones((T, 1))
X = control.trajectory(dynamics, U, jnp.zeros(1))
expected = jnp.cumsum(2 * jnp.ones(T)) % num_states
expected = jnp.concatenate((jnp.zeros(1), expected))
expected = jnp.reshape(expected, (T + 1, 1))
np.testing.assert_allclose(X, expected)
def testTrajectoryTimeVarying(self):
T = 6
def clip(x, lo, hi):
return jnp.minimum(hi, jnp.maximum(lo, x))
def dynamics(t, x, u):
# computes `(x + u) if t > T else 0`
return (x + u) * clip(t - T, 0, 1)
U = jnp.ones((2 * T, 1))
X = control.trajectory(dynamics, U, jnp.zeros(1))
expected = jnp.concatenate((jnp.zeros(T + 1), jnp.arange(T, dtype=float)))
expected = jnp.reshape(expected, (2 * T + 1, 1))
np.testing.assert_allclose(X, expected)
def testTrajectoryCyclicIndicator(self):
num_states = 3
def position(x):
'''finds the index of a standard basis vector, e.g. [0, 1, 0] -> 1'''
x = jnp.cumsum(x)
x = 1 - x
return jnp.sum(x, dtype=jnp.int32)
def dynamics(t, x, u):
'''moves the next standard basis vector'''
idx = (position(x) + u[0]) % num_states
return lax.dynamic_slice_in_dim(jnp.eye(num_states, dtype=jnp.int32), idx, 1)[0]
T = 8
U = jnp.ones((T, 1), dtype=jnp.int32)
X = control.trajectory(dynamics, U, jnp.eye(num_states, dtype=jnp.int32)[0])
expected = jnp.vstack((jnp.eye(num_states),) * 3)
np.testing.assert_allclose(X, expected)
def testLqrSolve(self):
dim, T = 2, 10
p = one_step_lqr(dim, T)
K, k = control.lqr_solve(p)
K_ = -jnp.stack(T * (jnp.eye(dim),))
np.testing.assert_allclose(K, K_, atol=1e-6, rtol=1e-6)
np.testing.assert_allclose(k, jnp.zeros((T, dim)))
def testLqrPredict(self):
dim, T = 2, 10
p = one_step_lqr(dim, T)
x0 = jnp.array(self.rng.normal(size=dim))
X, U = control.lqr_predict(p, x0)
np.testing.assert_allclose(X[0], x0)
np.testing.assert_allclose(U[0], -x0,
atol=1e-6, rtol=1e-6)
np.testing.assert_allclose(X[1:], jnp.zeros((T, 2)),
atol=1e-6, rtol=1e-6)
np.testing.assert_allclose(U[1:], jnp.zeros((T - 1, 2)),
atol=1e-6, rtol=1e-6)
def testIlqrWithLqrProblem(self):
dim, T, num_iters = 2, 10, 3
lqr = one_step_lqr(dim, T)
p = control_from_lqr(lqr)
x0 = jnp.array(self.rng.normal(size=dim))
X, U = control.ilqr(num_iters, p, x0, jnp.zeros((T, dim)))
np.testing.assert_allclose(X[0], x0)
np.testing.assert_allclose(U[0], -x0)
np.testing.assert_allclose(X[1:], jnp.zeros((T, 2)), atol=1E-15)
np.testing.assert_allclose(U[1:], jnp.zeros((T - 1, 2)), atol=1E-15)
def testIlqrWithLqrProblemSpecifiedGenerally(self):
dim, T, num_iters = 2, 10, 3
p = one_step_control(dim, T)
x0 = jnp.array(self.rng.normal(size=dim))
X, U = control.ilqr(num_iters, p, x0, jnp.zeros((T, dim)))
np.testing.assert_allclose(X[0], x0)
np.testing.assert_allclose(U[0], -x0)
np.testing.assert_allclose(X[1:], jnp.zeros((T, 2)), atol=1E-15)
np.testing.assert_allclose(U[1:], jnp.zeros((T - 1, 2)), atol=1E-15)
def testIlqrWithNonlinearProblem(self):
def cost(t, x, u):
return (x[0] ** 2. + 1e-3 * u[0] ** 2.) / (t + 1).astype(x.dtype)
def dynamics(t, x, u):
return (x ** 2. - u ** 2.) / (t + 1).astype(x.dtype)
T, num_iters, d = 10, 7, 1
p = control.ControlSpec(cost, dynamics, T, d, d)
x0 = jnp.array([0.2])
X, U = control.ilqr(num_iters, p, x0, 1e-5 * jnp.ones((T, d)))
assert_close = partial(np.testing.assert_allclose, atol=1e-2)
assert_close(X[0], x0)
assert_close(U[0] ** 2., x0 ** 2.)
assert_close(X[1:], jnp.zeros((T, d)))
assert_close(U[1:], jnp.zeros((T - 1, d)))
def testMpcWithLqrProblem(self):
dim, T, num_iters = 2, 10, 3
lqr = one_step_lqr(dim, T)
p = control_from_lqr(lqr)
x0 = jnp.array(self.rng.normal(size=dim))
solver = partial(control.ilqr, num_iters)
X, U = control.mpc_predict(solver, p, x0, jnp.zeros((T, dim)))
np.testing.assert_allclose(X[0], x0)
np.testing.assert_allclose(U[0], -x0)
np.testing.assert_allclose(X[1:], jnp.zeros((T, 2)))
np.testing.assert_allclose(U[1:], jnp.zeros((T - 1, 2)))
def testMpcWithLqrProblemSpecifiedGenerally(self):
dim, T, num_iters = 2, 10, 3
p = one_step_control(dim, T)
x0 = jnp.array(self.rng.normal(size=dim))
solver = partial(control.ilqr, num_iters)
X, U = control.mpc_predict(solver, p, x0, jnp.zeros((T, dim)))
np.testing.assert_allclose(X[0], x0)
np.testing.assert_allclose(U[0], -x0)
np.testing.assert_allclose(X[1:], jnp.zeros((T, 2)))
np.testing.assert_allclose(U[1:], jnp.zeros((T - 1, 2)))
def testMpcWithNonlinearProblem(self):
def cost(t, x, u):
return (x[0] ** 2. + 1e-3 * u[0] ** 2.) / (t + 1).astype(x.dtype)
def dynamics(t, x, u):
return (x ** 2. - u ** 2.) / (t + 1).astype(x.dtype)
T, num_iters, d = 10, 7, 1
p = control.ControlSpec(cost, dynamics, T, d, d)
x0 = jnp.array([0.2])
solver = partial(control.ilqr, num_iters)
X, U = control.mpc_predict(solver, p, x0, 1e-5 * jnp.ones((T, d)))
assert_close = partial(np.testing.assert_allclose, atol=1e-2)
assert_close(X[0], x0)
assert_close(U[0] ** 2., x0 ** 2.)
assert_close(X[1:], jnp.zeros((T, d)))
assert_close(U[1:], jnp.zeros((T - 1, d)))
if __name__ == '__main__':
absltest.main()