Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00
Merge pull request #1047 from hawkinsp/tests
Enable some tests that now pass.
commit f31d58fcd8
@@ -734,11 +734,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
        "rng": jtu.rand_default(), "lnp_op": getattr(lnp, op),
        "onp_op": getattr(onp, op)}
       for op in ["cumsum", "cumprod"]
-      # TODO(phawkins): replace both type lists with default_dtypes after a
-      # Jaxlib update includes
-      # https://github.com/google/jax/commit/86f5d189cf563b027c3cd00eea38072c003905c8
-      for dtype in [onp.float32, onp.int32]
-      for out_dtype in [onp.float32, onp.int32]
+      for dtype in default_dtypes
+      for out_dtype in default_dtypes
       for shape in all_shapes
       for axis in [None] + list(range(-len(shape), len(shape)))))
   def testCumSumProd(self, axis, shape, dtype, out_dtype, onp_op, lnp_op, rng):
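As a note on what each generated case checks: a minimal sketch, assuming the test's lnp/onp aliases correspond to jax.numpy and numpy, of the comparison testCumSumProd performs for one concrete combination (the real test sweeps ops, dtypes, shapes, and axes as parameterized above):

import jax.numpy as jnp
import numpy as np

# Compare jax.numpy's cumulative ops against numpy's for one concrete case.
x = np.arange(6, dtype=np.float32).reshape(2, 3)
for op in ["cumsum", "cumprod"]:
    expected = getattr(np, op)(x, axis=1)
    actual = getattr(jnp, op)(x, axis=1)
    np.testing.assert_allclose(expected, np.asarray(actual))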
@@ -1689,8 +1686,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
     self.assertAllClose(onp.zeros(3,), api.grad(f)(onp.ones(3,)),
                         check_dtypes=True)

-  # TODO(phawkins): enable test after Jaxlib 0.1.22 is released.
-  @unittest.skip("Requires Jaxlib >= 0.1.22.")
   def testIssue777(self):
     x = lnp.linspace(-200, 0, 4, dtype=onp.float32)
     f = api.grad(lambda x: lnp.sum(1 / (1 + lnp.exp(-x))))
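The numerical property testIssue777 exercises can be sketched as follows, assuming the modern jax API in place of the test's api/lnp aliases: the gradient of a sigmoid sum should stay finite (and go to zero) at large negative inputs rather than producing NaN.

import jax
import jax.numpy as jnp
import numpy as np

# d/dx sigmoid(x) computed naively is exp(-x) / (1 + exp(-x))**2, which
# overflows to inf/inf = NaN in float32 near x = -200; a stable backward
# pass returns ~0 there instead.
f = jax.grad(lambda x: jnp.sum(1 / (1 + jnp.exp(-x))))
x = jnp.linspace(-200, 0, 4, dtype=np.float32)
assert not np.any(np.isnan(np.asarray(f(x))))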
@@ -121,8 +121,6 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
     if test_autodiff:
       jtu.check_grads(lax_op, args, order=1, atol=1e-3, rtol=3e-3, eps=1e-3)

-  # TODO(phawkins): enable test after Jaxlib 0.1.22 is released.
-  @unittest.skip("Requires Jaxlib >= 0.1.22.")
   def testIssue980(self):
     x = onp.full((4,), -1e20, dtype=onp.float32)
     self.assertAllClose(onp.zeros((4,), dtype=onp.float32),
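The assertion in testIssue980 is truncated in this view, but the pattern is visible: a scipy.special-style op evaluated at an extreme input should saturate to its limit instead of returning NaN. A hedged sketch, with jax.scipy.special.expit standing in as an assumption for the truncated call:

import numpy as np
import jax.numpy as jnp
import jax.scipy.special as jsp_special

# expit(-1e20) should saturate cleanly to 0 in float32, not produce NaN.
x = np.full((4,), -1e20, dtype=np.float32)
result = jsp_special.expit(jnp.asarray(x))
np.testing.assert_allclose(np.zeros((4,), np.float32), np.asarray(result))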
@@ -2323,10 +2323,8 @@ class LaxVmapTest(jtu.JaxTestCase):
       if (lhs_bdim, rhs_bdim) != (None, None)
       for rng in [jtu.rand_default()]
   ))
-  # TODO(mattjj): some cases fail on CPU with the latest XLA (jaxlib) release
-  # apparently because of an AVX512 issue, and some cases fail on TPU just due
-  # to numerical tolerances
-  @jtu.skip_on_devices("cpu", "tpu")
+  # TODO(mattjj): some cases fail on TPU just due to numerical tolerances
+  @jtu.skip_on_devices("tpu")
   def testConvGeneralDilatedBatching(
       self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil,
       dimension_numbers, perms, feature_group_count, lhs_bdim, rhs_bdim, rng):
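A minimal sketch of the batching behavior this re-enabled-on-CPU test covers: mapping lax.conv_general_dilated over a batch axis with vmap. The shapes below are illustrative, not taken from the test's parameter sweep.

import jax
import jax.numpy as jnp
from jax import lax

# Default dimension numbers are NCHW inputs with an OIHW filter.
lhs = jnp.ones((8, 1, 3, 10, 10))  # leading vmap axis of size 8
rhs = jnp.ones((3, 3, 3, 3))       # 3 output channels, 3 input, 3x3 kernel

conv = lambda l: lax.conv_general_dilated(
    l, rhs, window_strides=(1, 1), padding="SAME")
out = jax.vmap(conv)(lhs)  # one independent convolution per batch entry
print(out.shape)           # (8, 1, 3, 10, 10)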
@@ -50,11 +50,6 @@ def _skip_if_unsupported_type(dtype):
   if (not FLAGS.jax_enable_x64 and
       dtype in (onp.dtype('float64'), onp.dtype('complex128'))):
     raise unittest.SkipTest("--jax_enable_x64 is not set")
-  if FLAGS.jax_test_dut and FLAGS.jax_test_dut.startswith("gpu"):
-    # TODO(b/129698548): enable complex128 tests when XLA/GPU has better
-    # complex128 support.
-    if dtype == onp.dtype('complex128'):
-      raise unittest.SkipTest("XLA/GPU complex128 support is incomplete.")


numpy_version = tuple(map(int, onp.version.version.split('.')))
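After this change, only the x64 flag gates wide dtypes; the GPU-specific complex128 branch is gone. A self-contained sketch of the surviving logic, with an explicit x64_enabled argument standing in for the FLAGS lookup:

import unittest
import numpy as onp

def _skip_if_unsupported_type(dtype, x64_enabled):
  # x64_enabled stands in for FLAGS.jax_enable_x64 in the original helper.
  if not x64_enabled and onp.dtype(dtype) in (
      onp.dtype('float64'), onp.dtype('complex128')):
    raise unittest.SkipTest("--jax_enable_x64 is not set")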
@@ -252,7 +252,6 @@ class LaxRandomTest(jtu.JaxTestCase):
         "a": a, "dtype": onp.dtype(dtype).name}
       for a in [0.1, 1., 10.]
       for dtype in [onp.float32, onp.float64]))
-  @jtu.skip_on_devices("tpu")  # TODO(b/130544008): re-enable when XLA fixed
   def testGamma(self, a, dtype):
     key = random.PRNGKey(0)
     rand = lambda key, a: random.gamma(key, a, (10000,), dtype)
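The statistical check behind testGamma can be sketched like this: draw samples from jax.random.gamma and compare their empirical distribution with the scipy gamma CDF via a Kolmogorov-Smirnov test. The 0.01 significance threshold is an illustrative choice, not the harness's actual criterion.

import numpy as np
import scipy.stats
from jax import random

a = 1.0
key = random.PRNGKey(0)
samples = random.gamma(key, a, (10000,), np.float32)
# KS test: empirical sample distribution vs. the analytic gamma CDF.
stat, pval = scipy.stats.kstest(np.asarray(samples), scipy.stats.gamma(a).cdf)
assert pval > 0.01, f"KS test rejected the gamma fit (p={pval:.4f})"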
@@ -264,7 +263,6 @@
     for samples in [uncompiled_samples, compiled_samples]:
       self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gamma(a).cdf)

-  @jtu.skip_on_devices("tpu")  # TODO(b/130544008): re-enable when XLA fixed
   def testGammaShape(self):
     key = random.PRNGKey(0)
     x = random.gamma(key, onp.array([0.2, 0.3]), shape=(3, 2))
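For the shape semantics testGammaShape covers, a short sketch: a parameter array of shape (2,) broadcasts against the requested sample shape (3, 2).

import numpy as np
from jax import random

key = random.PRNGKey(0)
# The (2,)-shaped concentration parameter broadcasts across the (3, 2) draw.
x = random.gamma(key, np.array([0.2, 0.3]), shape=(3, 2))
print(x.shape)  # (3, 2)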
@@ -164,8 +164,6 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):

   # TODO: currently it ignores the argument "shapes" and only tests dim=4
   @genNamedParametersNArgs(3, jtu.rand_default())
-  # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
-  @jtu.skip_on_devices("gpu", "tpu")
   def testMultivariateNormalLogPdf(self, rng, shapes, dtypes):
     scipy_fun = osp_stats.multivariate_normal.logpdf
     lax_fun = lsp_stats.multivariate_normal.logpdf
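A minimal sketch of the comparison this test performs, using dim=4 as the TODO above notes; the concrete inputs are arbitrary:

import numpy as np
import scipy.stats as osp_stats
import jax.scipy.stats as lsp_stats

x = np.linspace(-1.0, 1.0, 4).astype(np.float32)
mean = np.zeros(4, np.float32)
cov = np.eye(4, dtype=np.float32)

# jax.scipy's logpdf should agree with scipy's to within float32 tolerance.
expected = osp_stats.multivariate_normal.logpdf(x, mean, cov)
actual = lsp_stats.multivariate_normal.logpdf(x, mean, cov)
np.testing.assert_allclose(expected, np.asarray(actual), rtol=1e-4)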
@@ -283,8 +281,6 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
     self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
     self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

-  # TODO(phawkins): enable test after Jaxlib 0.1.22 is released.
-  @unittest.skip("Requires Jaxlib >= 0.1.22.")
   def testIssue972(self):
     self.assertAllClose(
         onp.ones((4,), onp.float32),
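The call under test in testIssue972 is truncated in this view; what survives is the saturation pattern: a CDF-like function evaluated deep in its upper tail should return exactly 1.0, not NaN. A hedged sketch with jax.scipy.stats.norm.cdf as an assumed stand-in for the truncated call:

import numpy as np
import jax.scipy.stats as lsp_stats

# norm.cdf far in the upper tail should saturate cleanly to 1.0.
x = np.full((4,), 1e5, dtype=np.float32)
np.testing.assert_allclose(
    np.ones((4,), np.float32), np.asarray(lsp_stats.norm.cdf(x)))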