mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
examples tests: avoid use of private jax utilities
This commit is contained in:
parent
6880e2f086
commit
3f1d21ad73
@ -13,19 +13,19 @@
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
import zlib
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
from jax import lax
|
||||
# TODO(jakevdp) avoid dependence on private test_util in examples
|
||||
from jax._src import test_util as jtu
|
||||
import jax.numpy as jnp
|
||||
|
||||
from examples import control
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
|
||||
def one_step_lqr(dim, T):
|
||||
@ -66,7 +66,10 @@ def one_step_control(dim, T):
|
||||
return control.ControlSpec(cost, dynamics, T, dim, dim)
|
||||
|
||||
|
||||
class ControlExampleTest(jtu.JaxTestCase):
|
||||
class ControlExampleTest(parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.rng = np.random.default_rng(zlib.adler32(self.__class__.__name__.encode()))
|
||||
|
||||
def testTrajectoryCyclicIntegerCounter(self):
|
||||
num_states = 3
|
||||
@ -80,14 +83,14 @@ class ControlExampleTest(jtu.JaxTestCase):
|
||||
X = control.trajectory(dynamics, U, jnp.zeros(1))
|
||||
expected = jnp.arange(T + 1) % num_states
|
||||
expected = jnp.reshape(expected, (T + 1, 1))
|
||||
self.assertAllClose(X, expected, check_dtypes=False)
|
||||
np.testing.assert_allclose(X, expected)
|
||||
|
||||
U = 2 * jnp.ones((T, 1))
|
||||
X = control.trajectory(dynamics, U, jnp.zeros(1))
|
||||
expected = jnp.cumsum(2 * jnp.ones(T)) % num_states
|
||||
expected = jnp.concatenate((jnp.zeros(1), expected))
|
||||
expected = jnp.reshape(expected, (T + 1, 1))
|
||||
self.assertAllClose(X, expected, check_dtypes=False)
|
||||
np.testing.assert_allclose(X, expected)
|
||||
|
||||
def testTrajectoryTimeVarying(self):
|
||||
T = 6
|
||||
@ -103,7 +106,7 @@ class ControlExampleTest(jtu.JaxTestCase):
|
||||
X = control.trajectory(dynamics, U, jnp.zeros(1))
|
||||
expected = jnp.concatenate((jnp.zeros(T + 1), jnp.arange(T)))
|
||||
expected = jnp.reshape(expected, (2 * T + 1, 1))
|
||||
self.assertAllClose(X, expected, check_dtypes=True)
|
||||
np.testing.assert_allclose(X, expected)
|
||||
|
||||
|
||||
def testTrajectoryCyclicIndicator(self):
|
||||
@ -125,7 +128,7 @@ class ControlExampleTest(jtu.JaxTestCase):
|
||||
U = jnp.ones((T, 1), dtype=jnp.int32)
|
||||
X = control.trajectory(dynamics, U, jnp.eye(num_states, dtype=jnp.int32)[0])
|
||||
expected = jnp.vstack((jnp.eye(num_states),) * 3)
|
||||
self.assertAllClose(X, expected, check_dtypes=True)
|
||||
np.testing.assert_allclose(X, expected)
|
||||
|
||||
|
||||
def testLqrSolve(self):
|
||||
@ -133,48 +136,45 @@ class ControlExampleTest(jtu.JaxTestCase):
|
||||
p = one_step_lqr(dim, T)
|
||||
K, k = control.lqr_solve(p)
|
||||
K_ = -jnp.stack(T * (jnp.eye(dim),))
|
||||
self.assertAllClose(K, K_, check_dtypes=True, atol=1e-6, rtol=1e-6)
|
||||
self.assertAllClose(k, jnp.zeros((T, dim)), check_dtypes=True)
|
||||
np.testing.assert_allclose(K, K_, atol=1e-6, rtol=1e-6)
|
||||
np.testing.assert_allclose(k, jnp.zeros((T, dim)))
|
||||
|
||||
|
||||
def testLqrPredict(self):
|
||||
randn = self.rng().randn
|
||||
dim, T = 2, 10
|
||||
p = one_step_lqr(dim, T)
|
||||
x0 = randn(dim)
|
||||
x0 = self.rng.normal(size=dim)
|
||||
X, U = control.lqr_predict(p, x0)
|
||||
self.assertAllClose(X[0], x0, check_dtypes=True)
|
||||
self.assertAllClose(U[0], -x0, check_dtypes=True,
|
||||
np.testing.assert_allclose(X[0], x0)
|
||||
np.testing.assert_allclose(U[0], -x0,
|
||||
atol=1e-6, rtol=1e-6)
|
||||
self.assertAllClose(X[1:], jnp.zeros((T, 2)), check_dtypes=True,
|
||||
np.testing.assert_allclose(X[1:], jnp.zeros((T, 2)),
|
||||
atol=1e-6, rtol=1e-6)
|
||||
self.assertAllClose(U[1:], jnp.zeros((T - 1, 2)), check_dtypes=True,
|
||||
np.testing.assert_allclose(U[1:], jnp.zeros((T - 1, 2)),
|
||||
atol=1e-6, rtol=1e-6)
|
||||
|
||||
|
||||
def testIlqrWithLqrProblem(self):
|
||||
randn = self.rng().randn
|
||||
dim, T, num_iters = 2, 10, 3
|
||||
lqr = one_step_lqr(dim, T)
|
||||
p = control_from_lqr(lqr)
|
||||
x0 = randn(dim)
|
||||
x0 = self.rng.normal(size=dim)
|
||||
X, U = control.ilqr(num_iters, p, x0, jnp.zeros((T, dim)))
|
||||
self.assertAllClose(X[0], x0, check_dtypes=True)
|
||||
self.assertAllClose(U[0], -x0, check_dtypes=True)
|
||||
self.assertAllClose(X[1:], jnp.zeros((T, 2)), check_dtypes=True)
|
||||
self.assertAllClose(U[1:], jnp.zeros((T - 1, 2)), check_dtypes=True)
|
||||
np.testing.assert_allclose(X[0], x0)
|
||||
np.testing.assert_allclose(U[0], -x0)
|
||||
np.testing.assert_allclose(X[1:], jnp.zeros((T, 2)), atol=1E-15)
|
||||
np.testing.assert_allclose(U[1:], jnp.zeros((T - 1, 2)), atol=1E-15)
|
||||
|
||||
|
||||
def testIlqrWithLqrProblemSpecifiedGenerally(self):
|
||||
randn = self.rng().randn
|
||||
dim, T, num_iters = 2, 10, 3
|
||||
p = one_step_control(dim, T)
|
||||
x0 = randn(dim)
|
||||
x0 = self.rng.normal(size=dim)
|
||||
X, U = control.ilqr(num_iters, p, x0, jnp.zeros((T, dim)))
|
||||
self.assertAllClose(X[0], x0, check_dtypes=True)
|
||||
self.assertAllClose(U[0], -x0, check_dtypes=True)
|
||||
self.assertAllClose(X[1:], jnp.zeros((T, 2)), check_dtypes=True)
|
||||
self.assertAllClose(U[1:], jnp.zeros((T - 1, 2)), check_dtypes=True)
|
||||
np.testing.assert_allclose(X[0], x0)
|
||||
np.testing.assert_allclose(U[0], -x0)
|
||||
np.testing.assert_allclose(X[1:], jnp.zeros((T, 2)), atol=1E-15)
|
||||
np.testing.assert_allclose(U[1:], jnp.zeros((T - 1, 2)), atol=1E-15)
|
||||
|
||||
|
||||
def testIlqrWithNonlinearProblem(self):
|
||||
@ -189,7 +189,7 @@ class ControlExampleTest(jtu.JaxTestCase):
|
||||
|
||||
x0 = jnp.array([0.2])
|
||||
X, U = control.ilqr(num_iters, p, x0, 1e-5 * jnp.ones((T, d)))
|
||||
assert_close = partial(self.assertAllClose, atol=1e-2, check_dtypes=True)
|
||||
assert_close = partial(np.testing.assert_allclose, atol=1e-2)
|
||||
assert_close(X[0], x0)
|
||||
assert_close(U[0] ** 2., x0 ** 2.)
|
||||
assert_close(X[1:], jnp.zeros((T, d)))
|
||||
@ -197,30 +197,28 @@ class ControlExampleTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
def testMpcWithLqrProblem(self):
|
||||
randn = self.rng().randn
|
||||
dim, T, num_iters = 2, 10, 3
|
||||
lqr = one_step_lqr(dim, T)
|
||||
p = control_from_lqr(lqr)
|
||||
x0 = randn(dim)
|
||||
x0 = self.rng.normal(size=dim)
|
||||
solver = partial(control.ilqr, num_iters)
|
||||
X, U = control.mpc_predict(solver, p, x0, jnp.zeros((T, dim)))
|
||||
self.assertAllClose(X[0], x0, check_dtypes=True)
|
||||
self.assertAllClose(U[0], -x0, check_dtypes=True)
|
||||
self.assertAllClose(X[1:], jnp.zeros((T, 2)), check_dtypes=True)
|
||||
self.assertAllClose(U[1:], jnp.zeros((T - 1, 2)), check_dtypes=True)
|
||||
np.testing.assert_allclose(X[0], x0)
|
||||
np.testing.assert_allclose(U[0], -x0)
|
||||
np.testing.assert_allclose(X[1:], jnp.zeros((T, 2)))
|
||||
np.testing.assert_allclose(U[1:], jnp.zeros((T - 1, 2)))
|
||||
|
||||
|
||||
def testMpcWithLqrProblemSpecifiedGenerally(self):
|
||||
randn = self.rng().randn
|
||||
dim, T, num_iters = 2, 10, 3
|
||||
p = one_step_control(dim, T)
|
||||
x0 = randn(dim)
|
||||
x0 = self.rng.normal(size=dim)
|
||||
solver = partial(control.ilqr, num_iters)
|
||||
X, U = control.mpc_predict(solver, p, x0, jnp.zeros((T, dim)))
|
||||
self.assertAllClose(X[0], x0, check_dtypes=True)
|
||||
self.assertAllClose(U[0], -x0, check_dtypes=True)
|
||||
self.assertAllClose(X[1:], jnp.zeros((T, 2)), check_dtypes=True)
|
||||
self.assertAllClose(U[1:], jnp.zeros((T - 1, 2)), check_dtypes=True)
|
||||
np.testing.assert_allclose(X[0], x0)
|
||||
np.testing.assert_allclose(U[0], -x0)
|
||||
np.testing.assert_allclose(X[1:], jnp.zeros((T, 2)))
|
||||
np.testing.assert_allclose(U[1:], jnp.zeros((T - 1, 2)))
|
||||
|
||||
|
||||
def testMpcWithNonlinearProblem(self):
|
||||
@ -236,7 +234,7 @@ class ControlExampleTest(jtu.JaxTestCase):
|
||||
x0 = jnp.array([0.2])
|
||||
solver = partial(control.ilqr, num_iters)
|
||||
X, U = control.mpc_predict(solver, p, x0, 1e-5 * jnp.ones((T, d)))
|
||||
assert_close = partial(self.assertAllClose, atol=1e-2, check_dtypes=True)
|
||||
assert_close = partial(np.testing.assert_allclose, atol=1e-2)
|
||||
assert_close(X[0], x0)
|
||||
assert_close(U[0] ** 2., x0 ** 2.)
|
||||
assert_close(X[1:], jnp.zeros((T, d)))
|
||||
@ -244,4 +242,4 @@ class ControlExampleTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
absltest.main()
|
||||
|
@ -15,13 +15,15 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
import zlib
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax import lax
|
||||
# TODO(jakevdp) avoid dependence on private test_util in examples.
|
||||
from jax._src import test_util as jtu
|
||||
from jax import random
|
||||
import jax.numpy as jnp
|
||||
|
||||
@ -32,24 +34,25 @@ sys.path.pop()
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
|
||||
def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape):
|
||||
jax_rng = random.PRNGKey(0)
|
||||
result_shape, params = init_fun(jax_rng, input_shape)
|
||||
rng = test_case.rng()
|
||||
result = apply_fun(params, rng.randn(*input_shape).astype(dtype="float32"))
|
||||
result = apply_fun(params, test_case.rng.normal(size=input_shape).astype("float32"))
|
||||
test_case.assertEqual(result.shape, result_shape)
|
||||
|
||||
|
||||
class ExamplesTest(jtu.JaxTestCase):
|
||||
class ExamplesTest(parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.rng = np.random.default_rng(zlib.adler32(self.__class__.__name__.encode()))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": "_input_shape={}".format(input_shape),
|
||||
"input_shape": input_shape}
|
||||
for input_shape in [(2, 20, 25, 2)])
|
||||
@jtu.skip_on_flag('jax_enable_x64', True)
|
||||
@unittest.skipIf(config.x64_enabled, "skip in x64 mode")
|
||||
def testIdentityBlockShape(self, input_shape):
|
||||
init_fun, apply_fun = resnet50.IdentityBlock(2, (4, 3))
|
||||
_CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
|
||||
@ -58,7 +61,7 @@ class ExamplesTest(jtu.JaxTestCase):
|
||||
{"testcase_name": "_input_shape={}".format(input_shape),
|
||||
"input_shape": input_shape}
|
||||
for input_shape in [(2, 20, 25, 3)])
|
||||
@jtu.skip_on_flag('jax_enable_x64', True)
|
||||
@unittest.skipIf(config.x64_enabled, "skip in x64 mode")
|
||||
def testConvBlockShape(self, input_shape):
|
||||
init_fun, apply_fun = resnet50.ConvBlock(3, (2, 3, 4))
|
||||
_CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
|
||||
@ -69,30 +72,26 @@ class ExamplesTest(jtu.JaxTestCase):
|
||||
"num_classes": num_classes, "input_shape": input_shape}
|
||||
for num_classes in [5, 10]
|
||||
for input_shape in [(224, 224, 3, 2)])
|
||||
@jtu.skip_on_flag('jax_enable_x64', True)
|
||||
@unittest.skipIf(config.x64_enabled, "skip in x64 mode")
|
||||
def testResNet50Shape(self, num_classes, input_shape):
|
||||
init_fun, apply_fun = resnet50.ResNet50(num_classes)
|
||||
_CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
|
||||
|
||||
def testKernelRegressionGram(self):
|
||||
n, d = 100, 20
|
||||
rng = self.rng()
|
||||
xs = rng.randn(n, d)
|
||||
xs = self.rng.normal(size=(n, d))
|
||||
kernel = lambda x, y: jnp.dot(x, y)
|
||||
self.assertAllClose(kernel_lsq.gram(kernel, xs), jnp.dot(xs, xs.T),
|
||||
check_dtypes=False, atol=1E-5)
|
||||
np.testing.assert_allclose(kernel_lsq.gram(kernel, xs), jnp.dot(xs, xs.T), atol=1E-5)
|
||||
|
||||
def testKernelRegressionTrainAndPredict(self):
|
||||
n, d = 100, 20
|
||||
rng = self.rng()
|
||||
truth = rng.randn(d)
|
||||
xs = rng.randn(n, d)
|
||||
truth = self.rng.normal(size=d)
|
||||
xs = self.rng.normal(size=(n, d))
|
||||
ys = jnp.dot(xs, truth)
|
||||
kernel = lambda x, y: jnp.dot(x, y, precision=lax.Precision.HIGH)
|
||||
predict = kernel_lsq.train(kernel, xs, ys)
|
||||
self.assertAllClose(predict(xs), ys, atol=1e-3, rtol=1e-3,
|
||||
check_dtypes=False)
|
||||
np.testing.assert_allclose(predict(xs), ys, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user