mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
change skip_on_devices to handle device tags
This commit is contained in:
parent
59e6efdaef
commit
b44d35664c
@ -423,16 +423,32 @@ def skip_if_unsupported_type(dtype):
|
||||
raise unittest.SkipTest(
|
||||
f"Type {dtype.name} not supported on {device_under_test()}")
|
||||
|
||||
def is_device_rocm():
|
||||
return xla_bridge.get_backend().platform_version.startswith('rocm')
|
||||
|
||||
def is_device_cuda():
|
||||
return xla_bridge.get_backend().platform_version.startswith('cuda')
|
||||
|
||||
def _get_device_tags():
|
||||
"""returns a set of tags definded for the device under test"""
|
||||
if is_device_rocm():
|
||||
device_tags = set([device_under_test(), "rocm"])
|
||||
elif is_device_cuda():
|
||||
device_tags = set([device_under_test(), "cuda"])
|
||||
else:
|
||||
device_tags = set([device_under_test()])
|
||||
return device_tags
|
||||
|
||||
def skip_on_devices(*disabled_devices):
|
||||
"""A decorator for test methods to skip the test on certain devices."""
|
||||
def skip(test_method):
|
||||
@functools.wraps(test_method)
|
||||
def test_method_wrapper(self, *args, **kwargs):
|
||||
device = device_under_test()
|
||||
if device in disabled_devices:
|
||||
device_tags = _get_device_tags()
|
||||
if device_tags & set(disabled_devices):
|
||||
test_name = getattr(test_method, '__name__', '[unknown test]')
|
||||
raise unittest.SkipTest(
|
||||
f"{test_name} not supported on {device.upper()}.")
|
||||
f"{test_name} not supported on device with tags {device_tags}.")
|
||||
return test_method(self, *args, **kwargs)
|
||||
return test_method_wrapper
|
||||
return skip
|
||||
|
@ -105,6 +105,7 @@ class FftTest(jtu.JaxTestCase):
|
||||
for shape in [(10,), (10, 10), (9,), (2, 3, 4), (2, 3, 4, 5)]
|
||||
for axes in _get_fftn_test_axes(shape)
|
||||
for s in _get_fftn_test_s(shape, axes)))
|
||||
@jtu.skip_on_devices("rocm")
|
||||
def testFftn(self, inverse, real, shape, dtype, axes, s):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
@ -123,6 +124,7 @@ class FftTest(jtu.JaxTestCase):
|
||||
tol = 0.15
|
||||
jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)
|
||||
|
||||
@jtu.skip_on_devices("rocm")
|
||||
def testIrfftTranspose(self):
|
||||
# regression test for https://github.com/google/jax/issues/6223
|
||||
def build_matrix(linear_func, size):
|
||||
@ -182,6 +184,7 @@ class FftTest(jtu.JaxTestCase):
|
||||
for shape in [(10,)]
|
||||
for n in [None, 1, 7, 13, 20]
|
||||
for axis in [-1, 0]))
|
||||
@jtu.skip_on_devices("rocm")
|
||||
def testFft(self, inverse, real, hermitian, shape, dtype, n, axis):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
@ -246,6 +249,7 @@ class FftTest(jtu.JaxTestCase):
|
||||
for dtype in (real_dtypes if real and not inverse else all_dtypes)
|
||||
for shape in [(16, 8, 4, 8), (16, 8, 4, 8, 4)]
|
||||
for axes in [(-2, -1), (0, 1), (1, 3), (-1, 2)]))
|
||||
@jtu.skip_on_devices("rocm")
|
||||
def testFft2(self, inverse, real, shape, dtype, axes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
|
Loading…
x
Reference in New Issue
Block a user