add finite test, add sep lims for binary_check

This commit is contained in:
Jacob Kelly 2020-04-09 11:16:00 -04:00
parent 8a65e9da60
commit 8503656ea8

View File

@ -56,6 +56,32 @@ class JetTest(jtu.JaxTestCase):
check_dtypes=True):
y, terms = jet(fun, primals, series)
expected_y, expected_terms = jvp_taylor(fun, primals, series)
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol,
check_dtypes=check_dtypes)
# TODO(duvenaud): Lower zero_series to actual zeros automatically.
if terms == zero_series:
terms = tree_map(np.zeros_like, expected_terms)
self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol,
check_dtypes=check_dtypes)
def check_jet_finite(self, fun, primals, series, atol=1e-5, rtol=1e-5,
check_dtypes=True):
y, terms = jet(fun, primals, series)
expected_y, expected_terms = jvp_taylor(fun, primals, series)
def _convert(x):
return np.where(np.isfinite(x), x, np.nan)
y = _convert(y)
expected_y = _convert(expected_y)
terms = _convert(np.asarray(terms))
expected_terms = _convert(np.asarray(expected_terms))
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol,
check_dtypes=check_dtypes)
@ -111,14 +137,21 @@ class JetTest(jtu.JaxTestCase):
terms_in = [rng.randn(*dims) for _ in range(order)]
self.check_jet(fun, (primal_in,), (terms_in,), atol=1e-4, rtol=1e-4)
def binary_check(self, fun, lims=[-2, 2], order=3):
def binary_check(self, fun, lims=[-2, 2], order=3, finite=True):
dims = 2, 3
rng = onp.random.RandomState(0)
primal_in = (transform(lims, rng.rand(*dims)),
transform(lims, rng.rand(*dims)))
if isinstance(lims, tuple):
x_lims, y_lims = lims
else:
x_lims, y_lims = lims, lims
primal_in = (transform(x_lims, rng.rand(*dims)),
transform(y_lims, rng.rand(*dims)))
series_in = ([rng.randn(*dims) for _ in range(order)],
[rng.randn(*dims) for _ in range(order)])
self.check_jet(fun, primal_in, series_in, atol=1e-4, rtol=1e-4)
if finite:
self.check_jet(fun, primal_in, series_in, atol=1e-4, rtol=1e-4)
else:
self.check_jet_finite(fun, primal_in, series_in, atol=1e-4, rtol=1e-4)
@jtu.skip_on_devices("tpu")
def test_exp(self): self.unary_check(np.exp)
@ -182,7 +215,7 @@ class JetTest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu")
def test_xor(self): self.binary_check(lambda x, y: np.logical_xor(x, y))
@jtu.skip_on_devices("tpu")
def test_pow2(self): self.binary_check(lambda x, y: x ** y, lims=[0.2, 15])
def test_pow(self): self.binary_check(lambda x, y: x ** y, lims=([0.2, 500], [-500, 500]), finite=False)
def test_process_call(self):
def f(x):