Remove stale limitations

PiperOrigin-RevId: 383652551
This commit is contained in:
David Majnemer 2021-07-08 09:41:52 -07:00 committed by jax authors
parent 78a689bb09
commit f0c30492dc
3 changed files with 0 additions and 19 deletions

View File

@ -238,12 +238,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
dtypes=[np.complex64, np.complex128],
devices=("cpu", "gpu"),
modes="compiled"),
missing_tf_kernel(
# Interesting: on TPU, complex64 works in eager
# mode, but fails otherwise.
dtypes=[np.complex64, np.complex128],
devices="tpu",
modes=("graph", "compiled")),
# TODO: very high tolerance
custom_numeric(
dtypes=[np.float32, np.complex64],
@ -518,14 +512,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
check_right_eigenvectors(operand, all_w_tf, all_vr_tf)
return [
# See https://github.com/google/jax/pull/3775#issuecomment-659407824;
# TODO(b/181414529): enable after XLA/GPU bug is fixed.
Jax2TfLimitation(
"XLA lowering bug",
dtypes=(np.complex64, np.complex128),
devices=("gpu",),
modes="compiled",
skip_tf_run=True),
missing_tf_kernel(
dtypes=dtypes.bfloat16,
devices="tpu",

View File

@ -569,8 +569,6 @@ class LaxVmapTest(jtu.JaxTestCase):
@jtu.skip_on_flag("jax_skip_slow_tests", True)
@jtu.ignore_warning(message="Using reduced precision for gradient.*")
def testSelectAndGatherAdd(self, dtype, padding):
if jtu.device_under_test() == "tpu" and dtype == dtypes.bfloat16:
raise SkipTest("bfloat16 _select_and_gather_add doesn't work on tpu")
rng = jtu.rand_small(self.rng())
all_configs = itertools.chain(
itertools.product(
@ -626,7 +624,6 @@ class LaxVmapTest(jtu.JaxTestCase):
for shape in [(5,), (3, 4, 5), (2, 3, 4, 5)]
for bdims in all_bdims(shape)
for fft_ndims in range(0, min(3, len(shape)) + 1)))
@jtu.skip_on_devices("tpu") # TODO(b/137993701): unimplemented cases.
def testFft(self, fft_ndims, shape, bdims):
rng = jtu.rand_default(self.rng())
ndims = len(shape)

View File

@ -109,8 +109,6 @@ class NNFunctionsTest(jtu.JaxTestCase):
partial(nn.gelu, approximate=True),
nn.relu, nn.softplus, nn.sigmoid)))
def testDtypeMatchesInput(self, dtype, fn):
if dtype is jnp.float16 and jtu.device_under_test() == "tpu":
self.skipTest("float16 not supported on TPU")
x = jnp.zeros((), dtype=dtype)
out = fn(x)
self.assertEqual(out.dtype, dtype)