From f0c30492dc152b4a39f22f7c06a732229bfc3def Mon Sep 17 00:00:00 2001
From: David Majnemer
Date: Thu, 8 Jul 2021 09:41:52 -0700
Subject: [PATCH] Remove stale limitations

PiperOrigin-RevId: 383652551
---
 .../jax2tf/tests/jax2tf_limitations.py | 14 --------------
 tests/lax_vmap_test.py                 |  3 ---
 tests/nn_test.py                       |  2 --
 3 files changed, 19 deletions(-)

diff --git a/jax/experimental/jax2tf/tests/jax2tf_limitations.py b/jax/experimental/jax2tf/tests/jax2tf_limitations.py
index 7c881670e..b7f2de3ee 100644
--- a/jax/experimental/jax2tf/tests/jax2tf_limitations.py
+++ b/jax/experimental/jax2tf/tests/jax2tf_limitations.py
@@ -238,12 +238,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
             dtypes=[np.complex64, np.complex128],
             devices=("cpu", "gpu"),
             modes="compiled"),
-        missing_tf_kernel(
-            # Interesting: on TPU, complex64 works in eager
-            # mode, but fails otherwise.
-            dtypes=[np.complex64, np.complex128],
-            devices="tpu",
-            modes=("graph", "compiled")),
         # TODO: very high tolerance
         custom_numeric(
             dtypes=[np.float32, np.complex64],
@@ -518,14 +512,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
       check_right_eigenvectors(operand, all_w_tf, all_vr_tf)
 
     return [
-        # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
-        # TODO(b/181414529): enable after XLA/GPU bug is fixed.
-        Jax2TfLimitation(
-            "XLA lowering bug",
-            dtypes=(np.complex64, np.complex128),
-            devices=("gpu",),
-            modes="compiled",
-            skip_tf_run=True),
         missing_tf_kernel(
             dtypes=dtypes.bfloat16,
             devices="tpu",
diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py
index 998fe97ad..9838afb38 100644
--- a/tests/lax_vmap_test.py
+++ b/tests/lax_vmap_test.py
@@ -569,8 +569,6 @@ class LaxVmapTest(jtu.JaxTestCase):
   @jtu.skip_on_flag("jax_skip_slow_tests", True)
   @jtu.ignore_warning(message="Using reduced precision for gradient.*")
   def testSelectAndGatherAdd(self, dtype, padding):
-    if jtu.device_under_test() == "tpu" and dtype == dtypes.bfloat16:
-      raise SkipTest("bfloat16 _select_and_gather_add doesn't work on tpu")
     rng = jtu.rand_small(self.rng())
     all_configs = itertools.chain(
         itertools.product(
@@ -626,7 +624,6 @@ class LaxVmapTest(jtu.JaxTestCase):
       for shape in [(5,), (3, 4, 5), (2, 3, 4, 5)]
       for bdims in all_bdims(shape)
       for fft_ndims in range(0, min(3, len(shape)) + 1)))
-  @jtu.skip_on_devices("tpu")  # TODO(b/137993701): unimplemented cases.
   def testFft(self, fft_ndims, shape, bdims):
     rng = jtu.rand_default(self.rng())
     ndims = len(shape)
diff --git a/tests/nn_test.py b/tests/nn_test.py
index 7f207ca05..9b672d47c 100644
--- a/tests/nn_test.py
+++ b/tests/nn_test.py
@@ -109,8 +109,6 @@ class NNFunctionsTest(jtu.JaxTestCase):
        partial(nn.gelu, approximate=True),
        nn.relu, nn.softplus, nn.sigmoid)))
   def testDtypeMatchesInput(self, dtype, fn):
-    if dtype is jnp.float16 and jtu.device_under_test() == "tpu":
-      self.skipTest("float16 not supported on TPU")
     x = jnp.zeros((), dtype=dtype)
     out = fn(x)
     self.assertEqual(out.dtype, dtype)