mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Disable failing test (#1744)
This commit is contained in:
parent
6f3cb1c3ee
commit
3d1d140acd
@ -1335,7 +1335,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
]
|
||||
for rng_idx_factory in [partial(jtu.rand_int, max(shape))]
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def testGather(self, shape, dtype, idxs, dnums, slice_sizes, rng_factory,
|
||||
def testGather(self, shape, dtype, idxs, dnums, slice_sizes, rng_factory,
|
||||
rng_idx_factory):
|
||||
rng = rng_factory()
|
||||
rng_idx = rng_idx_factory()
|
||||
@ -2902,6 +2902,8 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
for padding in ["VALID", "SAME"]
|
||||
for rng_factory in [jtu.rand_small]))
|
||||
def testSelectAndGatherAdd(self, dtype, padding, rng_factory):
|
||||
if jtu.device_under_test() == "tpu" and dtype == dtypes.bfloat16:
|
||||
raise SkipTest("bfloat16 _select_and_gather_add doesn't work on tpu")
|
||||
rng = rng_factory()
|
||||
all_configs = itertools.chain(
|
||||
itertools.product(
|
||||
|
Loading…
x
Reference in New Issue
Block a user