mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Remove stale limitations
PiperOrigin-RevId: 383652551
This commit is contained in:
parent
78a689bb09
commit
f0c30492dc
@ -238,12 +238,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
|
|||||||
dtypes=[np.complex64, np.complex128],
|
dtypes=[np.complex64, np.complex128],
|
||||||
devices=("cpu", "gpu"),
|
devices=("cpu", "gpu"),
|
||||||
modes="compiled"),
|
modes="compiled"),
|
||||||
missing_tf_kernel(
|
|
||||||
# Interesting: on TPU, complex64 works in eager
|
|
||||||
# mode, but fails otherwise.
|
|
||||||
dtypes=[np.complex64, np.complex128],
|
|
||||||
devices="tpu",
|
|
||||||
modes=("graph", "compiled")),
|
|
||||||
# TODO: very high tolerance
|
# TODO: very high tolerance
|
||||||
custom_numeric(
|
custom_numeric(
|
||||||
dtypes=[np.float32, np.complex64],
|
dtypes=[np.float32, np.complex64],
|
||||||
@ -518,14 +512,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
|
|||||||
check_right_eigenvectors(operand, all_w_tf, all_vr_tf)
|
check_right_eigenvectors(operand, all_w_tf, all_vr_tf)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
# See https://github.com/google/jax/pull/3775#issuecomment-659407824;
|
|
||||||
# TODO(b/181414529): enable after XLA/GPU bug is fixed.
|
|
||||||
Jax2TfLimitation(
|
|
||||||
"XLA lowering bug",
|
|
||||||
dtypes=(np.complex64, np.complex128),
|
|
||||||
devices=("gpu",),
|
|
||||||
modes="compiled",
|
|
||||||
skip_tf_run=True),
|
|
||||||
missing_tf_kernel(
|
missing_tf_kernel(
|
||||||
dtypes=dtypes.bfloat16,
|
dtypes=dtypes.bfloat16,
|
||||||
devices="tpu",
|
devices="tpu",
|
||||||
|
@ -569,8 +569,6 @@ class LaxVmapTest(jtu.JaxTestCase):
|
|||||||
@jtu.skip_on_flag("jax_skip_slow_tests", True)
|
@jtu.skip_on_flag("jax_skip_slow_tests", True)
|
||||||
@jtu.ignore_warning(message="Using reduced precision for gradient.*")
|
@jtu.ignore_warning(message="Using reduced precision for gradient.*")
|
||||||
def testSelectAndGatherAdd(self, dtype, padding):
|
def testSelectAndGatherAdd(self, dtype, padding):
|
||||||
if jtu.device_under_test() == "tpu" and dtype == dtypes.bfloat16:
|
|
||||||
raise SkipTest("bfloat16 _select_and_gather_add doesn't work on tpu")
|
|
||||||
rng = jtu.rand_small(self.rng())
|
rng = jtu.rand_small(self.rng())
|
||||||
all_configs = itertools.chain(
|
all_configs = itertools.chain(
|
||||||
itertools.product(
|
itertools.product(
|
||||||
@ -626,7 +624,6 @@ class LaxVmapTest(jtu.JaxTestCase):
|
|||||||
for shape in [(5,), (3, 4, 5), (2, 3, 4, 5)]
|
for shape in [(5,), (3, 4, 5), (2, 3, 4, 5)]
|
||||||
for bdims in all_bdims(shape)
|
for bdims in all_bdims(shape)
|
||||||
for fft_ndims in range(0, min(3, len(shape)) + 1)))
|
for fft_ndims in range(0, min(3, len(shape)) + 1)))
|
||||||
@jtu.skip_on_devices("tpu") # TODO(b/137993701): unimplemented cases.
|
|
||||||
def testFft(self, fft_ndims, shape, bdims):
|
def testFft(self, fft_ndims, shape, bdims):
|
||||||
rng = jtu.rand_default(self.rng())
|
rng = jtu.rand_default(self.rng())
|
||||||
ndims = len(shape)
|
ndims = len(shape)
|
||||||
|
@ -109,8 +109,6 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
|||||||
partial(nn.gelu, approximate=True),
|
partial(nn.gelu, approximate=True),
|
||||||
nn.relu, nn.softplus, nn.sigmoid)))
|
nn.relu, nn.softplus, nn.sigmoid)))
|
||||||
def testDtypeMatchesInput(self, dtype, fn):
|
def testDtypeMatchesInput(self, dtype, fn):
|
||||||
if dtype is jnp.float16 and jtu.device_under_test() == "tpu":
|
|
||||||
self.skipTest("float16 not supported on TPU")
|
|
||||||
x = jnp.zeros((), dtype=dtype)
|
x = jnp.zeros((), dtype=dtype)
|
||||||
out = fn(x)
|
out = fn(x)
|
||||||
self.assertEqual(out.dtype, dtype)
|
self.assertEqual(out.dtype, dtype)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user