examples tests: avoid use of private jax utilities

This commit is contained in:
Jake VanderPlas 2021-12-10 11:42:36 -08:00
parent 6880e2f086
commit 3f1d21ad73
2 changed files with 59 additions and 62 deletions

View File

@ -13,19 +13,19 @@
# limitations under the License.
from functools import partial
import zlib
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from jax import lax
# TODO(jakevdp) avoid dependence on private test_util in examples
from jax._src import test_util as jtu
import jax.numpy as jnp
from examples import control
from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
def one_step_lqr(dim, T):
@ -66,7 +66,10 @@ def one_step_control(dim, T):
return control.ControlSpec(cost, dynamics, T, dim, dim)
class ControlExampleTest(jtu.JaxTestCase):
class ControlExampleTest(parameterized.TestCase):
def setUp(self):
self.rng = np.random.default_rng(zlib.adler32(self.__class__.__name__.encode()))
def testTrajectoryCyclicIntegerCounter(self):
num_states = 3
@ -80,14 +83,14 @@ class ControlExampleTest(jtu.JaxTestCase):
X = control.trajectory(dynamics, U, jnp.zeros(1))
expected = jnp.arange(T + 1) % num_states
expected = jnp.reshape(expected, (T + 1, 1))
self.assertAllClose(X, expected, check_dtypes=False)
np.testing.assert_allclose(X, expected)
U = 2 * jnp.ones((T, 1))
X = control.trajectory(dynamics, U, jnp.zeros(1))
expected = jnp.cumsum(2 * jnp.ones(T)) % num_states
expected = jnp.concatenate((jnp.zeros(1), expected))
expected = jnp.reshape(expected, (T + 1, 1))
self.assertAllClose(X, expected, check_dtypes=False)
np.testing.assert_allclose(X, expected)
def testTrajectoryTimeVarying(self):
T = 6
@ -103,7 +106,7 @@ class ControlExampleTest(jtu.JaxTestCase):
X = control.trajectory(dynamics, U, jnp.zeros(1))
expected = jnp.concatenate((jnp.zeros(T + 1), jnp.arange(T)))
expected = jnp.reshape(expected, (2 * T + 1, 1))
self.assertAllClose(X, expected, check_dtypes=True)
np.testing.assert_allclose(X, expected)
def testTrajectoryCyclicIndicator(self):
@ -125,7 +128,7 @@ class ControlExampleTest(jtu.JaxTestCase):
U = jnp.ones((T, 1), dtype=jnp.int32)
X = control.trajectory(dynamics, U, jnp.eye(num_states, dtype=jnp.int32)[0])
expected = jnp.vstack((jnp.eye(num_states),) * 3)
self.assertAllClose(X, expected, check_dtypes=True)
np.testing.assert_allclose(X, expected)
def testLqrSolve(self):
@ -133,48 +136,45 @@ class ControlExampleTest(jtu.JaxTestCase):
p = one_step_lqr(dim, T)
K, k = control.lqr_solve(p)
K_ = -jnp.stack(T * (jnp.eye(dim),))
self.assertAllClose(K, K_, check_dtypes=True, atol=1e-6, rtol=1e-6)
self.assertAllClose(k, jnp.zeros((T, dim)), check_dtypes=True)
np.testing.assert_allclose(K, K_, atol=1e-6, rtol=1e-6)
np.testing.assert_allclose(k, jnp.zeros((T, dim)))
def testLqrPredict(self):
randn = self.rng().randn
dim, T = 2, 10
p = one_step_lqr(dim, T)
x0 = randn(dim)
x0 = self.rng.normal(size=dim)
X, U = control.lqr_predict(p, x0)
self.assertAllClose(X[0], x0, check_dtypes=True)
self.assertAllClose(U[0], -x0, check_dtypes=True,
np.testing.assert_allclose(X[0], x0)
np.testing.assert_allclose(U[0], -x0,
atol=1e-6, rtol=1e-6)
self.assertAllClose(X[1:], jnp.zeros((T, 2)), check_dtypes=True,
np.testing.assert_allclose(X[1:], jnp.zeros((T, 2)),
atol=1e-6, rtol=1e-6)
self.assertAllClose(U[1:], jnp.zeros((T - 1, 2)), check_dtypes=True,
np.testing.assert_allclose(U[1:], jnp.zeros((T - 1, 2)),
atol=1e-6, rtol=1e-6)
def testIlqrWithLqrProblem(self):
randn = self.rng().randn
dim, T, num_iters = 2, 10, 3
lqr = one_step_lqr(dim, T)
p = control_from_lqr(lqr)
x0 = randn(dim)
x0 = self.rng.normal(size=dim)
X, U = control.ilqr(num_iters, p, x0, jnp.zeros((T, dim)))
self.assertAllClose(X[0], x0, check_dtypes=True)
self.assertAllClose(U[0], -x0, check_dtypes=True)
self.assertAllClose(X[1:], jnp.zeros((T, 2)), check_dtypes=True)
self.assertAllClose(U[1:], jnp.zeros((T - 1, 2)), check_dtypes=True)
np.testing.assert_allclose(X[0], x0)
np.testing.assert_allclose(U[0], -x0)
np.testing.assert_allclose(X[1:], jnp.zeros((T, 2)), atol=1E-15)
np.testing.assert_allclose(U[1:], jnp.zeros((T - 1, 2)), atol=1E-15)
def testIlqrWithLqrProblemSpecifiedGenerally(self):
randn = self.rng().randn
dim, T, num_iters = 2, 10, 3
p = one_step_control(dim, T)
x0 = randn(dim)
x0 = self.rng.normal(size=dim)
X, U = control.ilqr(num_iters, p, x0, jnp.zeros((T, dim)))
self.assertAllClose(X[0], x0, check_dtypes=True)
self.assertAllClose(U[0], -x0, check_dtypes=True)
self.assertAllClose(X[1:], jnp.zeros((T, 2)), check_dtypes=True)
self.assertAllClose(U[1:], jnp.zeros((T - 1, 2)), check_dtypes=True)
np.testing.assert_allclose(X[0], x0)
np.testing.assert_allclose(U[0], -x0)
np.testing.assert_allclose(X[1:], jnp.zeros((T, 2)), atol=1E-15)
np.testing.assert_allclose(U[1:], jnp.zeros((T - 1, 2)), atol=1E-15)
def testIlqrWithNonlinearProblem(self):
@ -189,7 +189,7 @@ class ControlExampleTest(jtu.JaxTestCase):
x0 = jnp.array([0.2])
X, U = control.ilqr(num_iters, p, x0, 1e-5 * jnp.ones((T, d)))
assert_close = partial(self.assertAllClose, atol=1e-2, check_dtypes=True)
assert_close = partial(np.testing.assert_allclose, atol=1e-2)
assert_close(X[0], x0)
assert_close(U[0] ** 2., x0 ** 2.)
assert_close(X[1:], jnp.zeros((T, d)))
@ -197,30 +197,28 @@ class ControlExampleTest(jtu.JaxTestCase):
def testMpcWithLqrProblem(self):
randn = self.rng().randn
dim, T, num_iters = 2, 10, 3
lqr = one_step_lqr(dim, T)
p = control_from_lqr(lqr)
x0 = randn(dim)
x0 = self.rng.normal(size=dim)
solver = partial(control.ilqr, num_iters)
X, U = control.mpc_predict(solver, p, x0, jnp.zeros((T, dim)))
self.assertAllClose(X[0], x0, check_dtypes=True)
self.assertAllClose(U[0], -x0, check_dtypes=True)
self.assertAllClose(X[1:], jnp.zeros((T, 2)), check_dtypes=True)
self.assertAllClose(U[1:], jnp.zeros((T - 1, 2)), check_dtypes=True)
np.testing.assert_allclose(X[0], x0)
np.testing.assert_allclose(U[0], -x0)
np.testing.assert_allclose(X[1:], jnp.zeros((T, 2)))
np.testing.assert_allclose(U[1:], jnp.zeros((T - 1, 2)))
def testMpcWithLqrProblemSpecifiedGenerally(self):
randn = self.rng().randn
dim, T, num_iters = 2, 10, 3
p = one_step_control(dim, T)
x0 = randn(dim)
x0 = self.rng.normal(size=dim)
solver = partial(control.ilqr, num_iters)
X, U = control.mpc_predict(solver, p, x0, jnp.zeros((T, dim)))
self.assertAllClose(X[0], x0, check_dtypes=True)
self.assertAllClose(U[0], -x0, check_dtypes=True)
self.assertAllClose(X[1:], jnp.zeros((T, 2)), check_dtypes=True)
self.assertAllClose(U[1:], jnp.zeros((T - 1, 2)), check_dtypes=True)
np.testing.assert_allclose(X[0], x0)
np.testing.assert_allclose(U[0], -x0)
np.testing.assert_allclose(X[1:], jnp.zeros((T, 2)))
np.testing.assert_allclose(U[1:], jnp.zeros((T - 1, 2)))
def testMpcWithNonlinearProblem(self):
@ -236,7 +234,7 @@ class ControlExampleTest(jtu.JaxTestCase):
x0 = jnp.array([0.2])
solver = partial(control.ilqr, num_iters)
X, U = control.mpc_predict(solver, p, x0, 1e-5 * jnp.ones((T, d)))
assert_close = partial(self.assertAllClose, atol=1e-2, check_dtypes=True)
assert_close = partial(np.testing.assert_allclose, atol=1e-2)
assert_close(X[0], x0)
assert_close(U[0] ** 2., x0 ** 2.)
assert_close(X[1:], jnp.zeros((T, d)))
@ -244,4 +242,4 @@ class ControlExampleTest(jtu.JaxTestCase):
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())
absltest.main()

View File

@ -15,13 +15,15 @@
import os
import sys
import unittest
import zlib
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from jax import lax
# TODO(jakevdp) avoid dependence on private test_util in examples.
from jax._src import test_util as jtu
from jax import random
import jax.numpy as jnp
@ -32,24 +34,25 @@ sys.path.pop()
from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape):
jax_rng = random.PRNGKey(0)
result_shape, params = init_fun(jax_rng, input_shape)
rng = test_case.rng()
result = apply_fun(params, rng.randn(*input_shape).astype(dtype="float32"))
result = apply_fun(params, test_case.rng.normal(size=input_shape).astype("float32"))
test_case.assertEqual(result.shape, result_shape)
class ExamplesTest(jtu.JaxTestCase):
class ExamplesTest(parameterized.TestCase):
def setUp(self):
self.rng = np.random.default_rng(zlib.adler32(self.__class__.__name__.encode()))
@parameterized.named_parameters(
{"testcase_name": "_input_shape={}".format(input_shape),
"input_shape": input_shape}
for input_shape in [(2, 20, 25, 2)])
@jtu.skip_on_flag('jax_enable_x64', True)
@unittest.skipIf(config.x64_enabled, "skip in x64 mode")
def testIdentityBlockShape(self, input_shape):
init_fun, apply_fun = resnet50.IdentityBlock(2, (4, 3))
_CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
@ -58,7 +61,7 @@ class ExamplesTest(jtu.JaxTestCase):
{"testcase_name": "_input_shape={}".format(input_shape),
"input_shape": input_shape}
for input_shape in [(2, 20, 25, 3)])
@jtu.skip_on_flag('jax_enable_x64', True)
@unittest.skipIf(config.x64_enabled, "skip in x64 mode")
def testConvBlockShape(self, input_shape):
init_fun, apply_fun = resnet50.ConvBlock(3, (2, 3, 4))
_CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
@ -69,30 +72,26 @@ class ExamplesTest(jtu.JaxTestCase):
"num_classes": num_classes, "input_shape": input_shape}
for num_classes in [5, 10]
for input_shape in [(224, 224, 3, 2)])
@jtu.skip_on_flag('jax_enable_x64', True)
@unittest.skipIf(config.x64_enabled, "skip in x64 mode")
def testResNet50Shape(self, num_classes, input_shape):
init_fun, apply_fun = resnet50.ResNet50(num_classes)
_CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
def testKernelRegressionGram(self):
n, d = 100, 20
rng = self.rng()
xs = rng.randn(n, d)
xs = self.rng.normal(size=(n, d))
kernel = lambda x, y: jnp.dot(x, y)
self.assertAllClose(kernel_lsq.gram(kernel, xs), jnp.dot(xs, xs.T),
check_dtypes=False, atol=1E-5)
np.testing.assert_allclose(kernel_lsq.gram(kernel, xs), jnp.dot(xs, xs.T), atol=1E-5)
def testKernelRegressionTrainAndPredict(self):
n, d = 100, 20
rng = self.rng()
truth = rng.randn(d)
xs = rng.randn(n, d)
truth = self.rng.normal(size=d)
xs = self.rng.normal(size=(n, d))
ys = jnp.dot(xs, truth)
kernel = lambda x, y: jnp.dot(x, y, precision=lax.Precision.HIGH)
predict = kernel_lsq.train(kernel, xs, ys)
self.assertAllClose(predict(xs), ys, atol=1e-3, rtol=1e-3,
check_dtypes=False)
np.testing.assert_allclose(predict(xs), ys, atol=1e-3, rtol=1e-3)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
absltest.main()