Merge pull request #1047 from hawkinsp/tests

Enable some tests that now pass.
This commit is contained in:
Peter Hawkins 2019-07-21 22:43:56 +01:00 committed by GitHub
commit f31d58fcd8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 4 additions and 24 deletions

View File

@ -734,11 +734,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
"rng": jtu.rand_default(), "lnp_op": getattr(lnp, op),
"onp_op": getattr(onp, op)}
for op in ["cumsum", "cumprod"]
# TODO(phawkins): replace both type lists with default_dtypes after a
# Jaxlib update includes
# https://github.com/google/jax/commit/86f5d189cf563b027c3cd00eea38072c003905c8
for dtype in [onp.float32, onp.int32]
for out_dtype in [onp.float32, onp.int32]
for dtype in default_dtypes
for out_dtype in default_dtypes
for shape in all_shapes
for axis in [None] + list(range(-len(shape), len(shape)))))
def testCumSumProd(self, axis, shape, dtype, out_dtype, onp_op, lnp_op, rng):
@ -1689,8 +1686,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self.assertAllClose(onp.zeros(3,), api.grad(f)(onp.ones(3,)),
check_dtypes=True)
# TODO(phawkins): enable test after Jaxlib 0.1.22 is released.
@unittest.skip("Requires Jaxlib >= 0.1.22.")
def testIssue777(self):
x = lnp.linspace(-200, 0, 4, dtype=onp.float32)
f = api.grad(lambda x: lnp.sum(1 / (1 + lnp.exp(-x))))

View File

@ -121,8 +121,6 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
if test_autodiff:
jtu.check_grads(lax_op, args, order=1, atol=1e-3, rtol=3e-3, eps=1e-3)
# TODO(phawkins): enable test after Jaxlib 0.1.22 is released.
@unittest.skip("Requires Jaxlib >= 0.1.22.")
def testIssue980(self):
x = onp.full((4,), -1e20, dtype=onp.float32)
self.assertAllClose(onp.zeros((4,), dtype=onp.float32),

View File

@ -2323,10 +2323,8 @@ class LaxVmapTest(jtu.JaxTestCase):
if (lhs_bdim, rhs_bdim) != (None, None)
for rng in [jtu.rand_default()]
))
# TODO(mattjj): some cases fail on CPU with the latest XLA (jaxlib) release
# apparently because of an AVX512 issue, and some cases fail on TPU just due
# to numerical tolerances
@jtu.skip_on_devices("cpu", "tpu")
# TODO(mattjj): some cases fail on TPU just due to numerical tolerances
@jtu.skip_on_devices("tpu")
def testConvGeneralDilatedBatching(
self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil,
dimension_numbers, perms, feature_group_count, lhs_bdim, rhs_bdim, rng):

View File

@ -50,11 +50,6 @@ def _skip_if_unsupported_type(dtype):
if (not FLAGS.jax_enable_x64 and
dtype in (onp.dtype('float64'), onp.dtype('complex128'))):
raise unittest.SkipTest("--jax_enable_x64 is not set")
if FLAGS.jax_test_dut and FLAGS.jax_test_dut.startswith("gpu"):
# TODO(b/129698548): enable complex128 tests when XLA/GPU has better
# complex128 support.
if dtype == onp.dtype('complex128'):
raise unittest.SkipTest("XLA/GPU complex128 support is incomplete.")
numpy_version = tuple(map(int, onp.version.version.split('.')))

View File

@ -252,7 +252,6 @@ class LaxRandomTest(jtu.JaxTestCase):
"a": a, "dtype": onp.dtype(dtype).name}
for a in [0.1, 1., 10.]
for dtype in [onp.float32, onp.float64]))
@jtu.skip_on_devices("tpu") # TODO(b/130544008): re-enable when XLA fixed
def testGamma(self, a, dtype):
key = random.PRNGKey(0)
rand = lambda key, a: random.gamma(key, a, (10000,), dtype)
@ -264,7 +263,6 @@ class LaxRandomTest(jtu.JaxTestCase):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gamma(a).cdf)
@jtu.skip_on_devices("tpu") # TODO(b/130544008): re-enable when XLA fixed
def testGammaShape(self):
key = random.PRNGKey(0)
x = random.gamma(key, onp.array([0.2, 0.3]), shape=(3, 2))

View File

@ -164,8 +164,6 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
# TODO: currently it ignores the argument "shapes" and only tests dim=4
@genNamedParametersNArgs(3, jtu.rand_default())
# TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
@jtu.skip_on_devices("gpu", "tpu")
def testMultivariateNormalLogPdf(self, rng, shapes, dtypes):
scipy_fun = osp_stats.multivariate_normal.logpdf
lax_fun = lsp_stats.multivariate_normal.logpdf
@ -283,8 +281,6 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
# TODO(phawkins): enable test after Jaxlib 0.1.22 is released.
@unittest.skip("Requires Jaxlib >= 0.1.22.")
def testIssue972(self):
self.assertAllClose(
onp.ones((4,), onp.float32),