change skip_on_devices to handle device tags

This commit is contained in:
Reza Rahimi 2021-07-30 19:17:21 +00:00
parent 59e6efdaef
commit b44d35664c
2 changed files with 23 additions and 3 deletions

View File

@ -423,16 +423,32 @@ def skip_if_unsupported_type(dtype):
raise unittest.SkipTest(
f"Type {dtype.name} not supported on {device_under_test()}")
def is_device_rocm():
return xla_bridge.get_backend().platform_version.startswith('rocm')
def is_device_cuda():
return xla_bridge.get_backend().platform_version.startswith('cuda')
def _get_device_tags():
"""returns a set of tags definded for the device under test"""
if is_device_rocm():
device_tags = set([device_under_test(), "rocm"])
elif is_device_cuda():
device_tags = set([device_under_test(), "cuda"])
else:
device_tags = set([device_under_test()])
return device_tags
def skip_on_devices(*disabled_devices):
"""A decorator for test methods to skip the test on certain devices."""
def skip(test_method):
@functools.wraps(test_method)
def test_method_wrapper(self, *args, **kwargs):
device = device_under_test()
if device in disabled_devices:
device_tags = _get_device_tags()
if device_tags & set(disabled_devices):
test_name = getattr(test_method, '__name__', '[unknown test]')
raise unittest.SkipTest(
f"{test_name} not supported on {device.upper()}.")
f"{test_name} not supported on device with tags {device_tags}.")
return test_method(self, *args, **kwargs)
return test_method_wrapper
return skip

View File

@ -105,6 +105,7 @@ class FftTest(jtu.JaxTestCase):
for shape in [(10,), (10, 10), (9,), (2, 3, 4), (2, 3, 4, 5)]
for axes in _get_fftn_test_axes(shape)
for s in _get_fftn_test_s(shape, axes)))
@jtu.skip_on_devices("rocm")
def testFftn(self, inverse, real, shape, dtype, axes, s):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
@ -123,6 +124,7 @@ class FftTest(jtu.JaxTestCase):
tol = 0.15
jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)
@jtu.skip_on_devices("rocm")
def testIrfftTranspose(self):
# regression test for https://github.com/google/jax/issues/6223
def build_matrix(linear_func, size):
@ -182,6 +184,7 @@ class FftTest(jtu.JaxTestCase):
for shape in [(10,)]
for n in [None, 1, 7, 13, 20]
for axis in [-1, 0]))
@jtu.skip_on_devices("rocm")
def testFft(self, inverse, real, hermitian, shape, dtype, n, axis):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
@ -246,6 +249,7 @@ class FftTest(jtu.JaxTestCase):
for dtype in (real_dtypes if real and not inverse else all_dtypes)
for shape in [(16, 8, 4, 8), (16, 8, 4, 8, 4)]
for axes in [(-2, -1), (0, 1), (1, 3), (-1, 2)]))
@jtu.skip_on_devices("rocm")
def testFft2(self, inverse, real, shape, dtype, axes):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)