mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
add finite test, add sep lims for binary_check
This commit is contained in:
parent
8a65e9da60
commit
8503656ea8
@ -56,6 +56,32 @@ class JetTest(jtu.JaxTestCase):
|
||||
check_dtypes=True):
|
||||
y, terms = jet(fun, primals, series)
|
||||
expected_y, expected_terms = jvp_taylor(fun, primals, series)
|
||||
|
||||
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol,
|
||||
check_dtypes=check_dtypes)
|
||||
|
||||
# TODO(duvenaud): Lower zero_series to actual zeros automatically.
|
||||
if terms == zero_series:
|
||||
terms = tree_map(np.zeros_like, expected_terms)
|
||||
|
||||
self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol,
|
||||
check_dtypes=check_dtypes)
|
||||
|
||||
def check_jet_finite(self, fun, primals, series, atol=1e-5, rtol=1e-5,
|
||||
check_dtypes=True):
|
||||
|
||||
y, terms = jet(fun, primals, series)
|
||||
expected_y, expected_terms = jvp_taylor(fun, primals, series)
|
||||
|
||||
def _convert(x):
|
||||
return np.where(np.isfinite(x), x, np.nan)
|
||||
|
||||
y = _convert(y)
|
||||
expected_y = _convert(expected_y)
|
||||
|
||||
terms = _convert(np.asarray(terms))
|
||||
expected_terms = _convert(np.asarray(expected_terms))
|
||||
|
||||
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol,
|
||||
check_dtypes=check_dtypes)
|
||||
|
||||
@ -111,14 +137,21 @@ class JetTest(jtu.JaxTestCase):
|
||||
terms_in = [rng.randn(*dims) for _ in range(order)]
|
||||
self.check_jet(fun, (primal_in,), (terms_in,), atol=1e-4, rtol=1e-4)
|
||||
|
||||
def binary_check(self, fun, lims=[-2, 2], order=3):
|
||||
def binary_check(self, fun, lims=[-2, 2], order=3, finite=True):
|
||||
dims = 2, 3
|
||||
rng = onp.random.RandomState(0)
|
||||
primal_in = (transform(lims, rng.rand(*dims)),
|
||||
transform(lims, rng.rand(*dims)))
|
||||
if isinstance(lims, tuple):
|
||||
x_lims, y_lims = lims
|
||||
else:
|
||||
x_lims, y_lims = lims, lims
|
||||
primal_in = (transform(x_lims, rng.rand(*dims)),
|
||||
transform(y_lims, rng.rand(*dims)))
|
||||
series_in = ([rng.randn(*dims) for _ in range(order)],
|
||||
[rng.randn(*dims) for _ in range(order)])
|
||||
self.check_jet(fun, primal_in, series_in, atol=1e-4, rtol=1e-4)
|
||||
if finite:
|
||||
self.check_jet(fun, primal_in, series_in, atol=1e-4, rtol=1e-4)
|
||||
else:
|
||||
self.check_jet_finite(fun, primal_in, series_in, atol=1e-4, rtol=1e-4)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_exp(self): self.unary_check(np.exp)
|
||||
@ -182,7 +215,7 @@ class JetTest(jtu.JaxTestCase):
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_xor(self): self.binary_check(lambda x, y: np.logical_xor(x, y))
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_pow2(self): self.binary_check(lambda x, y: x ** y, lims=[0.2, 15])
|
||||
def test_pow(self): self.binary_check(lambda x, y: x ** y, lims=([0.2, 500], [-500, 500]), finite=False)
|
||||
|
||||
def test_process_call(self):
|
||||
def f(x):
|
||||
|
Loading…
x
Reference in New Issue
Block a user