Disable failing test (#1744)

This commit is contained in:
Skye Wanderman-Milne 2019-11-22 11:34:14 -08:00 committed by GitHub
parent 6f3cb1c3ee
commit 3d1d140acd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1335,7 +1335,7 @@ class LaxTest(jtu.JaxTestCase):
]
for rng_idx_factory in [partial(jtu.rand_int, max(shape))]
for rng_factory in [jtu.rand_default]))
def testGather(self, shape, dtype, idxs, dnums, slice_sizes, rng_factory,
def testGather(self, shape, dtype, idxs, dnums, slice_sizes, rng_factory,
rng_idx_factory):
rng = rng_factory()
rng_idx = rng_idx_factory()
@ -2902,6 +2902,8 @@ class LaxVmapTest(jtu.JaxTestCase):
for padding in ["VALID", "SAME"]
for rng_factory in [jtu.rand_small]))
def testSelectAndGatherAdd(self, dtype, padding, rng_factory):
if jtu.device_under_test() == "tpu" and dtype == dtypes.bfloat16:
raise SkipTest("bfloat16 _select_and_gather_add doesn't work on tpu")
rng = rng_factory()
all_configs = itertools.chain(
itertools.product(