[callback] Disable stream_executor tests.

PiperOrigin-RevId: 559252832
This commit is contained in:
George Necula 2023-08-22 16:14:25 -07:00 committed by jax authors
parent bf29c5e5f1
commit 26f091e446
2 changed files with 13 additions and 65 deletions

View File

@ -1050,6 +1050,9 @@ jax_test(
jax_test(
name = "python_callback_test",
srcs = ["python_callback_test.py"],
disable_configs = [
"tpu_se", # Host callback does not work on stream executor.
],
deps = [
"//jax:experimental",
],

View File

@ -129,14 +129,16 @@ mlir.register_lowering(callback_p, callback_lowering, platform="tpu")
class PythonCallbackTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def tearDown(self):
super().tearDown()
dispatch.runtime_tokens.clear()
def test_callback_with_scalar_values(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
@jax.jit
def f(x):
return callback(lambda x: x + np.float32(1.),
@ -205,8 +207,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
jax.effects_barrier()
def test_callback_with_single_return_value(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
@jax.jit
def f():
@ -218,8 +218,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
np.testing.assert_allclose(out, np.ones(4, np.float32))
def test_callback_with_multiple_return_values(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
@jax.jit
def f():
@ -233,8 +231,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
np.testing.assert_allclose(y, np.ones(5, np.int32))
def test_callback_with_multiple_arguments_and_return_values(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def _callback(x, y, z):
return (x, y + z)
@ -310,8 +306,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
self.assertAllClose(res, (result0, result1, result2, result3))
def test_callback_with_pytree_arguments_and_return_values(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def _callback(x):
return dict(y=[x])
@ -326,8 +320,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
self.assertEqual(out, dict(y=[2.]))
def test_callback_inside_of_while_loop_of_scalars(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def _callback(x):
return (x + 1.).astype(x.dtype)
@ -345,8 +337,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
self.assertEqual(out, 10.)
def test_callback_inside_of_while_loop(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def _callback(x):
return (x + 1.).astype(x.dtype)
@ -367,8 +357,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
np.testing.assert_allclose(out, jnp.arange(10., 15.))
def test_callback_inside_of_cond_of_scalars(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def _callback1(x):
return (x + 1.).astype(x.dtype)
@ -395,8 +383,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
self.assertEqual(out, 0.)
def test_callback_inside_of_cond(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def _callback1(x):
return x + 1.
@ -423,8 +409,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
np.testing.assert_allclose(out, jnp.zeros(2))
def test_callback_inside_of_scan_of_scalars(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def _callback(x):
return (x + 1.).astype(x.dtype)
@ -443,8 +427,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
self.assertEqual(out, 10.)
def test_callback_inside_of_scan(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def _callback(x):
return x + 1.
@ -463,8 +445,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
np.testing.assert_allclose(out, jnp.arange(2.) + 10.)
def test_callback_inside_of_pmap_of_scalars(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def _callback(x):
return (x + 1.).astype(x.dtype)
@ -479,8 +459,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
out, np.arange(jax.local_device_count(), dtype=np.float32) + 1.)
def test_callback_inside_of_pmap(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def _callback(x):
return x + 1.
@ -499,13 +477,16 @@ class PythonCallbackTest(jtu.JaxTestCase):
class PurePythonCallbackTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def tearDown(self):
super().tearDown()
dispatch.runtime_tokens.clear()
def test_pure_callback_passes_ndarrays_without_jit(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def cb(x):
self.assertIs(type(x), np.ndarray)
@ -516,8 +497,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
f(jnp.array(2.))
def test_simple_pure_callback(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
@jax.jit
def f(x):
@ -614,8 +593,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
jax.effects_barrier()
def test_can_vmap_pure_callback(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
@jax.jit
@jax.vmap
@ -649,8 +626,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
rtol=1E-7, check_dtypes=False)
def test_vmap_vectorized_callback(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def cb(x):
self.assertTupleEqual(x.shape, ())
@ -698,8 +673,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
jax.effects_barrier()
def test_can_pmap_pure_callback(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
@jax.pmap
def f(x):
@ -708,10 +681,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
np.testing.assert_allclose(out, np.sin(np.arange(jax.local_device_count())))
def test_can_pjit_pure_callback_under_hard_xmap(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest(
'Host callback not supported for runtime type: stream_executor.'
)
if not hasattr(xla_client.OpSharding.Type, 'MANUAL'):
raise unittest.SkipTest('Manual partitioning needed for pure_callback')
@ -760,8 +729,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
f(2.)
def test_can_take_grad_of_pure_callback_with_custom_jvp(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
@jax.custom_jvp
def sin(x):
@ -780,8 +747,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
np.testing.assert_allclose(out, jnp.cos(2.))
def test_callback_inside_of_cond(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def _callback1(x):
return x + 1.
@ -806,8 +771,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
np.testing.assert_allclose(out, jnp.zeros(2))
def test_callback_inside_of_scan(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def _callback(x):
return x + 1.
@ -825,8 +788,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
np.testing.assert_allclose(out, jnp.arange(2.) + 10.)
def test_callback_inside_of_while_loop(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def _cond_callback(x):
return np.any(x < 10)
@ -850,8 +811,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
np.testing.assert_allclose(out, jnp.arange(10., 15.))
def test_callback_inside_of_pmap(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def _callback(x):
return x + 1.
@ -868,8 +827,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
np.arange(2 * jax.local_device_count()).reshape([-1, 2]) + 1.)
def test_callback_inside_xmap(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def _callback(x):
return (x + 1.).astype(x.dtype)
@ -884,8 +841,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
np.testing.assert_allclose(out, jnp.arange(1., 41.))
def test_vectorized_callback_inside_xmap(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def _callback(x):
return (x + 1.).astype(x.dtype)
@ -900,8 +855,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
np.testing.assert_allclose(out, jnp.arange(1., 41.))
def test_array_layout_is_preserved(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def g(x):
return jax.pure_callback(lambda x: x, x, x)
@ -910,10 +863,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
np.testing.assert_allclose(g(x), x)
def test_can_shard_pure_callback_maximally(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest(
'Host callback not supported for runtime type: stream_executor.'
)
mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
@ -934,10 +883,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
)
def test_can_shard_pure_callback_manually(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest(
'Host callback not supported for runtime type: stream_executor.'
)
mesh = Mesh(np.array(jax.devices()), axis_names=('x',))