Remove stale limitations
PiperOrigin-RevId: 383652551
parent 78a689bb09
commit f0c30492dc
@@ -238,12 +238,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
            dtypes=[np.complex64, np.complex128],
            devices=("cpu", "gpu"),
            modes="compiled"),
        missing_tf_kernel(
            # Interesting: on TPU, complex64 works in eager
            # mode, but fails otherwise.
            dtypes=[np.complex64, np.complex128],
            devices="tpu",
            modes=("graph", "compiled")),
        # TODO: very high tolerance
        custom_numeric(
            dtypes=[np.float32, np.complex64],
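For context, the "compiled" mode named in these limitation entries refers to running the jax2tf-converted function under XLA compilation in TensorFlow. A minimal sketch of that round trip, assuming TensorFlow is installed; f_jax and the complex64 input are illustrative and not taken from the harness:

# Hedged sketch: convert a JAX function with jax2tf and compare its eager
# execution against XLA-compiled ("compiled" mode) execution.
import numpy as np
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

def f_jax(x):
  return jnp.sin(x) * jnp.conj(x)   # arbitrary complex-friendly function

x = np.ones((4,), dtype=np.complex64)
f_tf = jax2tf.convert(f_jax)
eager_out = f_tf(x)                                     # eager TF execution
compiled_out = tf.function(f_tf, jit_compile=True)(x)   # "compiled" mode
np.testing.assert_allclose(eager_out, compiled_out, rtol=1e-5)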
@@ -518,14 +512,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
      check_right_eigenvectors(operand, all_w_tf, all_vr_tf)

    return [
        # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
        # TODO(b/181414529): enable after XLA/GPU bug is fixed.
        Jax2TfLimitation(
            "XLA lowering bug",
            dtypes=(np.complex64, np.complex128),
            devices=("gpu",),
            modes="compiled",
            skip_tf_run=True),
        missing_tf_kernel(
            dtypes=dtypes.bfloat16,
            devices="tpu",
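The check_right_eigenvectors helper referenced above is defined earlier in this test file; the property it asserts is the usual one, A·vr = vr·diag(w). A self-contained NumPy sketch of such a check (the function and tolerance here are illustrative, not the harness's own):

# Hedged sketch: for eigenvalues w and right eigenvectors vr (as columns),
# A @ vr should equal vr * w, i.e. vr @ diag(w).
import numpy as np

def check_right_eigenvectors(a, w, vr, tol=1e-3):
  np.testing.assert_allclose(np.matmul(a, vr), vr * w, atol=tol)

a = np.random.randn(4, 4).astype(np.float32)
w, vr = np.linalg.eig(a)   # may be complex even for a real input
check_right_eigenvectors(a, w, vr)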
@@ -569,8 +569,6 @@ class LaxVmapTest(jtu.JaxTestCase):
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  @jtu.ignore_warning(message="Using reduced precision for gradient.*")
  def testSelectAndGatherAdd(self, dtype, padding):
    if jtu.device_under_test() == "tpu" and dtype == dtypes.bfloat16:
      raise SkipTest("bfloat16 _select_and_gather_add doesn't work on tpu")
    rng = jtu.rand_small(self.rng())
    all_configs = itertools.chain(
        itertools.product(
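The test shown here exercises _select_and_gather_add, the primitive that appears when differentiating windowed max reductions (hence the reduced-precision warning being ignored above). A short sketch that reaches the same code path through public APIs only; the pooling function and shapes are made up for illustration:

# Hedged sketch: grad of a max reduce_window goes through
# _select_and_gather_add during linearization; vmap batches it.
import jax
import jax.numpy as jnp
from jax import lax

def pooled_sum(x):
  # 1-D max pooling, window 2, stride 2, followed by a scalar reduction.
  return lax.reduce_window(x, -jnp.inf, lax.max, (2,), (2,), "VALID").sum()

xs = jnp.arange(12., dtype=jnp.float32).reshape(3, 4)
grads = jax.vmap(jax.grad(pooled_sum))(xs)   # shape (3, 4)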
@@ -626,7 +624,6 @@ class LaxVmapTest(jtu.JaxTestCase):
          for shape in [(5,), (3, 4, 5), (2, 3, 4, 5)]
          for bdims in all_bdims(shape)
          for fft_ndims in range(0, min(3, len(shape)) + 1)))
  @jtu.skip_on_devices("tpu")  # TODO(b/137993701): unimplemented cases.
  def testFft(self, fft_ndims, shape, bdims):
    rng = jtu.rand_default(self.rng())
    ndims = len(shape)
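These vmap cases batch FFTs over arbitrary batch dimensions (all_bdims and the rest of the harness plumbing are defined elsewhere in the file). The underlying property is easy to state with public APIs alone; the shapes below are arbitrary:

# Hedged sketch: vmapping an N-D FFT over a leading batch axis should match
# applying the FFT to each batch element separately.
import jax
import jax.numpy as jnp
import numpy as np

x = np.random.randn(8, 3, 4, 5).astype(np.complex64)
batched = jax.vmap(jnp.fft.fftn)(x)                    # FFT over trailing dims
reference = jnp.stack([jnp.fft.fftn(xi) for xi in x])  # per-example FFT
np.testing.assert_allclose(batched, reference, rtol=1e-3, atol=1e-3)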
@@ -109,8 +109,6 @@ class NNFunctionsTest(jtu.JaxTestCase):
          partial(nn.gelu, approximate=True),
          nn.relu, nn.softplus, nn.sigmoid)))
  def testDtypeMatchesInput(self, dtype, fn):
    if dtype is jnp.float16 and jtu.device_under_test() == "tpu":
      self.skipTest("float16 not supported on TPU")
    x = jnp.zeros((), dtype=dtype)
    out = fn(x)
    self.assertEqual(out.dtype, dtype)
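The test touched here asserts that each activation preserves its input dtype. A hedged, self-contained version of that check using public jax.nn functions (the dtype list is illustrative):

# Hedged sketch: common activations should return the same dtype they receive.
import jax.numpy as jnp
from jax import nn

for dtype in (jnp.float32, jnp.bfloat16):
  for fn in (nn.relu, nn.softplus, nn.sigmoid):
    x = jnp.zeros((), dtype=dtype)
    assert fn(x).dtype == dtype, (fn, dtype)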