misc test skips

This commit is contained in:
Matthew Johnson 2019-04-15 09:14:43 -07:00
parent 909a9db881
commit c5a381ed4d
2 changed files with 7 additions and 1 deletions

View File

@ -725,11 +725,15 @@ class IndexedUpdateTest(jtu.JaxTestCase):
for update_dtype in ([dtype] if op == UpdateOps.ADD else float_dtypes)
for rng in [jtu.rand_default()]))
def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
rng, indexer, op):
rng, indexer, op):
if FLAGS.jax_test_dut in ("cpu", "tpu") and not shape:
# TODO(b/127315062): this case causes an XLA crash on CPU/TPU. Reenable
# when fixed.
raise unittest.SkipTest("Test case crashes on CPU")
if (FLAGS.jax_test_dut == "tpu" and isinstance(indexer, slice)
and onp.zeros(shape)[indexer].size == 0):
# TODO(phawkins): this case causes an XLA crash on TPU. Reenable when fixed.
raise unittest.SkipTest("Test case crashes on TPU")
jax_op = ops.index_update if op == UpdateOps.UPDATE else ops.index_add
jax_fn = lambda x, y: jax_op(x, indexer, y)

View File

@ -190,6 +190,7 @@ class LaxRandomTest(jtu.JaxTestCase):
"a": a, "dtype": onp.dtype(dtype).name}
for a in [0.1, 1., 10.]
for dtype in [onp.float32, onp.float64]))
@jtu.skip_on_devices("tpu") # TODO(phawkins): re-enable
def testGamma(self, a, dtype):
key = random.PRNGKey(0)
rand = lambda key, a: random.gamma(key, a, (10000,), dtype)
@ -201,6 +202,7 @@ class LaxRandomTest(jtu.JaxTestCase):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gamma(a).cdf)
@jtu.skip_on_devices("tpu") # TODO(phawkins): re-enable
def testGammaShape(self):
key = random.PRNGKey(0)
x = random.gamma(key, onp.array([0.2, 0.3]), shape=(3, 2))