Test changes for out-of-tree backend.

Jake Hall 2023-09-13 16:35:02 +01:00
parent 5a15ba90db
commit f59a4163fa
12 changed files with 69 additions and 21 deletions

View File

@@ -327,13 +327,13 @@ def _get_device_tags():
   device_tags = {device_under_test()}
   return device_tags

-def skip_on_devices(*disabled_devices):
-  """A decorator for test methods to skip the test on certain devices."""
+def _device_filter(predicate):
   def skip(test_method):
     @functools.wraps(test_method)
     def test_method_wrapper(self, *args, **kwargs):
       device_tags = _get_device_tags()
-      if device_tags & set(disabled_devices):
+      if not predicate(device_tags):
         test_name = getattr(test_method, '__name__', '[unknown test]')
         raise unittest.SkipTest(
           f"{test_name} not supported on device with tags {device_tags}.")
@@ -341,6 +341,26 @@ def skip_on_devices(*disabled_devices):
       return test_method_wrapper
     return skip

+def skip_on_devices(*disabled_devices):
+  """A decorator for test methods to skip the test on certain devices."""
+  def predicate(device_tags):
+    return not (device_tags & set(disabled_devices))
+  return _device_filter(predicate)
+
+def run_on_devices(*enabled_devices):
+  """A decorator for test methods to run the test only on certain devices."""
+  def predicate(device_tags):
+    return device_tags & set(enabled_devices)
+  return _device_filter(predicate)
+
+def device_supports_buffer_donation():
+  """A decorator for test methods to run the test only on devices that support
+  buffer donation."""
+  def predicate(device_tags):
+    return device_tags & set(mlir._platforms_with_donation)
+  return _device_filter(predicate)
+
 def set_host_platform_device_count(nr_devices: int):
   """Returns a closure that undoes the operation."""
   prev_xla_flags = os.getenv("XLA_FLAGS")
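
The refactor above flips the polarity of the device check: skip_on_devices denylists platforms, while the new run_on_devices allowlists them, so on an out-of-tree backend the tests never anticipated, an allowlisted test is skipped rather than run and failed. device_supports_buffer_donation() keys the check off mlir._platforms_with_donation instead of repeating a hardcoded device list at each test. A minimal usage sketch (the test class and bodies are illustrative, not part of this commit):

from jax._src import test_util as jtu

class ExampleTest(jtu.JaxTestCase):

  # Denylist: runs on every platform except TPU, including unknown
  # out-of-tree backends.
  @jtu.skip_on_devices("tpu")
  def test_everywhere_but_tpu(self):
    self.assertEqual(1 + 1, 2)

  # Allowlist: runs only on CPU or GPU; any other platform is skipped.
  @jtu.run_on_devices("cpu", "gpu")
  def test_cpu_or_gpu_only(self):
    self.assertEqual(2 * 2, 4)

  # Runs only on platforms that implement buffer donation; note the
  # decorator takes no arguments and must be called.
  @jtu.device_supports_buffer_donation()
  def test_donation_dependent(self):
    self.assertEqual(4 // 2, 2)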

View File

@@ -43,7 +43,7 @@ with contextlib.suppress(ImportError):
 class JaxAotTest(jtu.JaxTestCase):

-  @jtu.skip_on_devices('cpu', 'gpu')
+  @jtu.run_on_devices('tpu')
   def test_pickle_pjit_lower(self):
     if jtu.is_se_tpu():
       raise unittest.SkipTest('StreamExecutor not supported.')

View File

@@ -420,6 +420,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
       ("argnums", "donate_argnums", 0),
       ("argnames", "donate_argnames", 'x'),
   )
+  @jtu.device_supports_buffer_donation()
   def test_jit_donate_invalidates_input(self, argnum_type, argnum_val):
     # We can't just use `lambda x: x` because JAX simplifies this away to an
     # empty XLA computation.
@@ -433,6 +434,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
       ("donate_argnums", "donate_argnums", (2, 3)),
       ("donate_argnames", "donate_argnames", ('c', 'd')),
   )
+  @jtu.device_supports_buffer_donation()
   def test_jit_donate_static_argnums(self, argnum_type, argnum_val):
     jit_fun = self.jit(
         lambda a, b, c, d: ((a + b + c), (a + b + d)),
@@ -447,6 +449,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
     self.assertDeleted(c)
     self.assertDeleted(d)

+  @jtu.device_supports_buffer_donation()
   def test_jit_donate_argnames_kwargs_static_argnums(self):
     jit_fun = self.jit(
         lambda a, b, c, d, e: ((a + b + c), (a + b + d), (a + b + e)),
@@ -468,6 +471,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
       ("argnums", "donate_argnums", 0),
       ("argnames", "donate_argnames", 'x'),
   )
+  @jtu.device_supports_buffer_donation()
   def test_jit_donate_weak_type(self, argnum_type, argnum_val):
     # input has weak-type, output does not have weak-type
     move = self.jit(lambda x: x.astype(int), **{argnum_type: argnum_val})
@@ -495,6 +499,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
     # Gives: RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer.
     print(x_copy)  # doesn't crash

+  @jtu.device_supports_buffer_donation()
   def test_specify_donate_argnums_and_argnames(self):
     @partial(jax.jit, donate_argnums=0, donate_argnames=('inp2', 'inp3'))
     def f(inp1, inp2, inp3):
@@ -512,6 +517,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
   def test_resolve_argnums_signature_fail(self):
     api_util.resolve_argnums(int, None, None, None, None)  # doesn't crash

+  @jtu.device_supports_buffer_donation()
   def test_donate_argnames_with_args(self):
     @partial(jax.jit, donate_argnames='inp1')
     def f(inp1):
@@ -521,6 +527,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
     f(x)
     self.assertDeleted(x)

+  @jtu.device_supports_buffer_donation()
   def test_donate_argnums_with_kwargs(self):
     @partial(jax.jit, donate_argnums=0)
     def f(inp1):
@@ -10074,6 +10081,7 @@ class CustomApiTest(jtu.JaxTestCase):
 class BufferDonationTest(jtu.BufferDonationTestCase):

+  @jtu.device_supports_buffer_donation()
   def test_pmap_donate_argnums_invalidates_input(self):
     move = api.pmap(lambda x: x + x - x, donate_argnums=0)
     n = jax.local_device_count()
@@ -10082,6 +10090,7 @@ class BufferDonationTest(jtu.BufferDonationTestCase):
     self.assertDeleted(x)
     np.testing.assert_allclose(y, [1.] * n)

+  @jtu.device_supports_buffer_donation()
   def test_pmap_nested_donate_ignored(self):
     pmap_fun = jit(lambda x: api.pmap(lambda y: y ** 2, donate_argnums=0)(x))
     a = api.pmap(lambda x: x)(jnp.array([1]))
@@ -10194,7 +10203,7 @@ class BackendsTest(jtu.JaxTestCase):
   @unittest.skipIf(not sys.executable, "test requires sys.executable")
   @unittest.skipIf(platform.system() == "Darwin",
                    "Warning doesn't apply on Mac")
-  @jtu.skip_on_devices("gpu", "tpu")
+  @jtu.run_on_devices("cpu")
   def test_cpu_warning_suppression(self):
     warning_expected = (
         "import jax; "

View File

@@ -62,8 +62,8 @@ all_shapes = nonempty_array_shapes + empty_array_shapes
 class DLPackTest(jtu.JaxTestCase):
   def setUp(self):
     super().setUp()
-    if jtu.device_under_test() == "tpu":
-      self.skipTest("DLPack not supported on TPU")
+    if jtu.device_under_test() not in ["cpu", "gpu"]:
+      self.skipTest(f"DLPack not supported on {jtu.device_under_test()}")

   @jtu.sample_product(
     shape=all_shapes,

View File

@@ -313,11 +313,11 @@ class JaxArrayTest(jtu.JaxTestCase):
   def test_wrong_num_arrays(self):
     shape = (8, 2)
     mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    devices = jax.local_devices()[:8]  # Taking up to 8 devices
     s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
     inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
     di_map = s.devices_indices_map(shape)
-    bufs = [jax.device_put(inp_data[di_map[d]], d)
-            for d in jax.local_devices()]
+    bufs = [jax.device_put(inp_data[di_map[d]], d) for d in devices]
     with self.assertRaisesRegex(
         ValueError,
         r'Expected 8 per-device arrays \(this is how many devices are addressable '
@@ -910,8 +910,9 @@ class ShardingTest(jtu.JaxTestCase):
     mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
     mps = jax.sharding.NamedSharding(mesh, pspec)
+    devices = jax.local_devices()[:8]  # Taking up to 8 devices

-    devices_sharding = jax.sharding.PositionalSharding(jax.devices())
+    devices_sharding = jax.sharding.PositionalSharding(devices)
     devices_sharding = devices_sharding.reshape(shape).replicate(axes)
     if transpose:
       devices_sharding = devices_sharding.T
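
Several tests in this commit also replace jax.devices() with a device list capped at eight entries: their meshes and shardings are written for an 8-device topology, and an out-of-tree backend may expose more addressable devices than that. A sketch of the pattern (requires at least eight devices; the mesh shape and axis names are illustrative — the tests themselves use jtu.create_global_mesh, which skips when too few devices are available):

import numpy as np
import jax
from jax.sharding import Mesh

# Cap the device list so the mesh stays (4, 2) even when the backend
# exposes more than eight devices.
devices = np.array(jax.local_devices()[:8]).reshape(4, 2)
mesh = Mesh(devices, axis_names=('x', 'y'))
print(mesh.shape)  # {'x': 4, 'y': 2}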

View File

@@ -520,7 +520,8 @@ class CheckifyTransformTests(jtu.JaxTestCase):
       # binary func
       return x / y

-    mesh = jax.sharding.Mesh(np.array(jax.devices()), ["dev"])
+    devices = jax.local_devices()[:8]  # Taking up to 8 devices
+    mesh = jax.sharding.Mesh(np.array(devices), ["dev"])
     ps = NamedSharding(mesh, jax.sharding.PartitionSpec("dev"))
     inp = np.arange(8)
     x = array.make_array_from_callback(inp.shape, ps, lambda idx: inp[idx])
@@ -1311,7 +1312,7 @@ class LowerableChecksTest(jtu.JaxTestCase):
     config.update("jax_experimental_unsafe_xla_runtime_errors", self.prev)
     super().tearDown()

-  @jtu.skip_on_devices("tpu")
+  @jtu.run_on_devices("cpu", "gpu")
   def test_jit(self):
     @jax.jit
     def f(x):

View File

@@ -57,6 +57,11 @@ foo = 2
 class CliDebuggerTest(jtu.JaxTestCase):

+  def setUp(self):
+    super().setUp()
+    if jtu.device_under_test() not in ["cpu", "gpu", "tpu"]:
+      self.skipTest(f"Host callback not supported on {jtu.device_under_test()}")
+
   def test_debugger_eof(self):
     stdin, stdout = make_fake_stdin_stdout([])

View File

@@ -36,7 +36,7 @@ class RnnTest(jtu.JaxTestCase):
       num_layers=[1, 4],
       bidirectional=[True, False],
   )
-  @jtu.skip_on_devices("cpu", "tpu", "rocm")
+  @jtu.run_on_devices("cuda")
   def test_lstm(self, batch_size: int, seq_len: int, input_size: int,
                 hidden_size: int, num_layers: int, bidirectional: bool):
     if lib.version < (0, 4, 7):

View File

@@ -305,6 +305,7 @@ class PJitTest(jtu.BufferDonationTestCase):
                         check_dtypes=False)

   @jtu.with_mesh([('x', 2)])
+  @jtu.run_on_devices('cpu', 'gpu', 'tpu')
   def testBufferDonation(self):
     @partial(pjit, in_shardings=P('x'), out_shardings=P('x'), donate_argnums=0)
     def f(x, y):
@@ -318,6 +319,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     self.assertNotDeleted(y)
     self.assertDeleted(x)

+  @jtu.run_on_devices('cpu', 'gpu', 'tpu')
   def testBufferDonationWithNames(self):
     mesh = jtu.create_global_mesh((2,), ('x'))
     s = NamedSharding(mesh, P('x'))
@@ -333,6 +335,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     self.assertNotDeleted(x)
     self.assertDeleted(y)

+  @jtu.run_on_devices('cpu', 'gpu', 'tpu')
   def testBufferDonationWithKwargs(self):
     mesh = jtu.create_global_mesh((2,), ('x'))
     s = NamedSharding(mesh, P('x'))
@@ -351,6 +354,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     self.assertDeleted(y)
     self.assertDeleted(z)

+  @jtu.run_on_devices('cpu', 'gpu', 'tpu')
   def testBufferDonationWithPyTreeKwargs(self):
     mesh = jtu.create_global_mesh((2,), ('x'))
     s = NamedSharding(mesh, P('x'))
@@ -2045,9 +2049,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
     # Explicitly put on the ordering of devices which does not match the mesh
     # ordering to make sure we reorder them in the constructor and the output
     # is correct.
+    local_devices = jax.local_devices()[:8]  # Take 8 local devices
     di_map = s.devices_indices_map(shape)
-    bufs = [jax.device_put(inp_data[di_map[d]], d)
-            for d in jax.local_devices()]
+    bufs = [jax.device_put(inp_data[di_map[d]], d) for d in local_devices]
     arr = array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)
     f = pjit(lambda x: x, out_shardings=s)
@@ -2793,7 +2797,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
     self.assertLen(final_out.sharding.device_set, 1)
     self.assertArraysEqual(final_out, inp * 6)

-  @jtu.skip_on_devices("gpu", "cpu")
+  @jtu.run_on_devices("tpu")
   def test_pjit_with_backend_arg(self):
     def _check(out, expected_device, expected_out):
       self.assertEqual(out.device(), expected_device)

View File

@@ -73,6 +73,8 @@ class PythonCallbackTest(jtu.JaxTestCase):
   def setUp(self):
     super().setUp()
+    if jtu.device_under_test() not in ["cpu", "gpu", "tpu"]:
+      self.skipTest(f"Host callback not supported on {jtu.device_under_test()}")
     if xla_bridge.get_backend().runtime_type == 'stream_executor':
       raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
@@ -496,6 +498,8 @@ class PureCallbackTest(jtu.JaxTestCase):
   def setUp(self):
     super().setUp()
+    if jtu.device_under_test() not in ["cpu", "gpu", "tpu"]:
+      self.skipTest(f"Host callback not supported on {jtu.device_under_test()}")
     if xla_bridge.get_backend().runtime_type == 'stream_executor':
       raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
@@ -879,6 +883,8 @@ class IOCallbackTest(jtu.JaxTestCase):
     super().setUp()
     if xla_bridge.get_backend().runtime_type == 'stream_executor':
       raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
+    if jtu.device_under_test() not in ["cpu", "gpu", "tpu"]:
+      self.skipTest(f"Host callback not supported on {jtu.device_under_test()}")

   def tearDown(self):
     super().tearDown()
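
All three callback test classes now carry the same guard, and the pattern generalizes to any test that depends on host callbacks. A sketch with a hypothetical test class:

from jax._src import test_util as jtu

class HostCallbackDependentTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    # jtu.device_under_test() returns the platform name (e.g. "cpu");
    # anything outside the in-tree platforms is skipped.
    if jtu.device_under_test() not in ["cpu", "gpu", "tpu"]:
      self.skipTest(f"Host callback not supported on {jtu.device_under_test()}")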

View File

@@ -158,7 +158,7 @@ class ShardMapTest(jtu.JaxTestCase):
     self.assertEqual(c.device_buffers[0].shape, (2, 8))

   def test_collective_permute(self):
-    devices = np.array(jax.devices())
+    devices = np.array(jax.devices()[:8])  # Take up to 8 devices
     mesh = Mesh(devices, axis_names=('x'))
     a = jax.device_put(
         jnp.arange(8 * 8).reshape((8, 8)),
@@ -176,7 +176,7 @@ class ShardMapTest(jtu.JaxTestCase):
     self.assertAllClose(c[1, :], a[0, :])

   def test_all_to_all(self):
-    devices = np.array(jax.devices())
+    devices = np.array(jax.devices()[:8])  # Take up to 8 devices
     mesh = Mesh(devices, axis_names=('x'))
     a = jax.device_put(
         jnp.arange(8 * 8).reshape((8, 8)),
@@ -416,7 +416,7 @@ class ShardMapTest(jtu.JaxTestCase):
     y_dot_expected = jnp.sin(jnp.arange(8.)) * (jnp.cos(x) * x).sum()
     self.assertAllClose(y_dot, y_dot_expected, check_dtypes=False)

-  @jtu.skip_on_devices("cpu")
+  @jtu.run_on_devices('gpu', 'tpu')
   def test_axis_index(self):
     mesh = Mesh(np.array(jax.devices()[:4]), ('x',))
@@ -522,6 +522,7 @@ class ShardMapTest(jtu.JaxTestCase):
     self.assertIn('out_names', e.params)
     self.assertEqual(e.params['out_names'], ({0: ('x', 'y',)},))

+  @jtu.run_on_devices('cpu', 'gpu', 'tpu')
   def test_debug_print_jit(self):
     mesh = Mesh(jax.devices(), ('i',))
@@ -745,6 +746,7 @@ class ShardMapTest(jtu.JaxTestCase):
       # error!
     jax.jit(g)(x)  # doesn't crash

+  @jtu.run_on_devices('cpu', 'gpu', 'tpu')
   def test_key_array_with_replicated_last_tile_dim(self):
     # See https://github.com/google/jax/issues/16137

View File

@@ -459,7 +459,7 @@ class XMapTest(XMapTestCase):
     self.assertAllClose(f_mapped(x, x), expected)

   @jtu.with_and_without_mesh
-  @jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.
+  @jtu.run_on_devices("gpu", "tpu")  # In/out aliasing not supported on CPU.
   def testBufferDonation(self, mesh, axis_resources):
     shard = lambda x: x
     if axis_resources:
@@ -476,7 +476,7 @@ class XMapTest(XMapTestCase):
     self.assertNotDeleted(y)
     self.assertDeleted(x)

-  @jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.
+  @jtu.run_on_devices("gpu", "tpu")  # In/out aliasing not supported on CPU.
   @jtu.with_mesh([('x', 2)])
   @jtu.ignore_warning(category=UserWarning,  # SPMD test generates warning.
                       message="Some donated buffers were not usable*")