diff --git a/tests/BUILD b/tests/BUILD index 1a8039238..fd3ff163e 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1050,6 +1050,9 @@ jax_test( jax_test( name = "python_callback_test", srcs = ["python_callback_test.py"], + disable_configs = [ + "tpu_se", # Host callback does not work on stream executor. + ], deps = [ "//jax:experimental", ], diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 0b0b184ab..ec417018b 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -129,14 +129,16 @@ mlir.register_lowering(callback_p, callback_lowering, platform="tpu") class PythonCallbackTest(jtu.JaxTestCase): + def setUp(self): + super().setUp() + if xla_bridge.get_backend().runtime_type == 'stream_executor': + raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') + def tearDown(self): super().tearDown() dispatch.runtime_tokens.clear() def test_callback_with_scalar_values(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - @jax.jit def f(x): return callback(lambda x: x + np.float32(1.), @@ -205,8 +207,6 @@ class PythonCallbackTest(jtu.JaxTestCase): jax.effects_barrier() def test_callback_with_single_return_value(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') @jax.jit def f(): @@ -218,8 +218,6 @@ class PythonCallbackTest(jtu.JaxTestCase): np.testing.assert_allclose(out, np.ones(4, np.float32)) def test_callback_with_multiple_return_values(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') @jax.jit def f(): @@ -233,8 +231,6 @@ class PythonCallbackTest(jtu.JaxTestCase): np.testing.assert_allclose(y, np.ones(5, np.int32)) def test_callback_with_multiple_arguments_and_return_values(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') def _callback(x, y, z): return (x, y + z) @@ -310,8 +306,6 @@ class PythonCallbackTest(jtu.JaxTestCase): self.assertAllClose(res, (result0, result1, result2, result3)) def test_callback_with_pytree_arguments_and_return_values(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') def _callback(x): return dict(y=[x]) @@ -326,8 +320,6 @@ class PythonCallbackTest(jtu.JaxTestCase): self.assertEqual(out, dict(y=[2.])) def test_callback_inside_of_while_loop_of_scalars(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') def _callback(x): return (x + 1.).astype(x.dtype) @@ -345,8 +337,6 @@ class PythonCallbackTest(jtu.JaxTestCase): self.assertEqual(out, 10.) def test_callback_inside_of_while_loop(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') def _callback(x): return (x + 1.).astype(x.dtype) @@ -367,8 +357,6 @@ class PythonCallbackTest(jtu.JaxTestCase): np.testing.assert_allclose(out, jnp.arange(10., 15.)) def test_callback_inside_of_cond_of_scalars(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') def _callback1(x): return (x + 1.).astype(x.dtype) @@ -395,8 +383,6 @@ class PythonCallbackTest(jtu.JaxTestCase): self.assertEqual(out, 0.) def test_callback_inside_of_cond(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') def _callback1(x): return x + 1. @@ -423,8 +409,6 @@ class PythonCallbackTest(jtu.JaxTestCase): np.testing.assert_allclose(out, jnp.zeros(2)) def test_callback_inside_of_scan_of_scalars(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') def _callback(x): return (x + 1.).astype(x.dtype) @@ -443,8 +427,6 @@ class PythonCallbackTest(jtu.JaxTestCase): self.assertEqual(out, 10.) def test_callback_inside_of_scan(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') def _callback(x): return x + 1. @@ -463,8 +445,6 @@ class PythonCallbackTest(jtu.JaxTestCase): np.testing.assert_allclose(out, jnp.arange(2.) + 10.) def test_callback_inside_of_pmap_of_scalars(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') def _callback(x): return (x + 1.).astype(x.dtype) @@ -479,8 +459,6 @@ class PythonCallbackTest(jtu.JaxTestCase): out, np.arange(jax.local_device_count(), dtype=np.float32) + 1.) def test_callback_inside_of_pmap(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') def _callback(x): return x + 1. @@ -499,13 +477,16 @@ class PythonCallbackTest(jtu.JaxTestCase): class PurePythonCallbackTest(jtu.JaxTestCase): + def setUp(self): + super().setUp() + if xla_bridge.get_backend().runtime_type == 'stream_executor': + raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') + def tearDown(self): super().tearDown() dispatch.runtime_tokens.clear() def test_pure_callback_passes_ndarrays_without_jit(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') def cb(x): self.assertIs(type(x), np.ndarray) @@ -516,8 +497,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase): f(jnp.array(2.)) def test_simple_pure_callback(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') @jax.jit def f(x): @@ -614,8 +593,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase): jax.effects_barrier() def test_can_vmap_pure_callback(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') @jax.jit @jax.vmap @@ -649,8 +626,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase): rtol=1E-7, check_dtypes=False) def test_vmap_vectorized_callback(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') def cb(x): self.assertTupleEqual(x.shape, ()) @@ -698,8 +673,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase): jax.effects_barrier() def test_can_pmap_pure_callback(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') @jax.pmap def f(x): @@ -708,10 +681,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase): np.testing.assert_allclose(out, np.sin(np.arange(jax.local_device_count()))) def test_can_pjit_pure_callback_under_hard_xmap(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest( - 'Host callback not supported for runtime type: stream_executor.' - ) if not hasattr(xla_client.OpSharding.Type, 'MANUAL'): raise unittest.SkipTest('Manual partitioning needed for pure_callback') @@ -760,8 +729,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase): f(2.) def test_can_take_grad_of_pure_callback_with_custom_jvp(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') @jax.custom_jvp def sin(x): @@ -780,8 +747,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase): np.testing.assert_allclose(out, jnp.cos(2.)) def test_callback_inside_of_cond(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') def _callback1(x): return x + 1. @@ -806,8 +771,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase): np.testing.assert_allclose(out, jnp.zeros(2)) def test_callback_inside_of_scan(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') def _callback(x): return x + 1. @@ -825,8 +788,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase): np.testing.assert_allclose(out, jnp.arange(2.) + 10.) def test_callback_inside_of_while_loop(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') def _cond_callback(x): return np.any(x < 10) @@ -850,8 +811,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase): np.testing.assert_allclose(out, jnp.arange(10., 15.)) def test_callback_inside_of_pmap(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') def _callback(x): return x + 1. @@ -868,8 +827,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase): np.arange(2 * jax.local_device_count()).reshape([-1, 2]) + 1.) def test_callback_inside_xmap(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') def _callback(x): return (x + 1.).astype(x.dtype) @@ -884,8 +841,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase): np.testing.assert_allclose(out, jnp.arange(1., 41.)) def test_vectorized_callback_inside_xmap(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') def _callback(x): return (x + 1.).astype(x.dtype) @@ -900,8 +855,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase): np.testing.assert_allclose(out, jnp.arange(1., 41.)) def test_array_layout_is_preserved(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') def g(x): return jax.pure_callback(lambda x: x, x, x) @@ -910,10 +863,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase): np.testing.assert_allclose(g(x), x) def test_can_shard_pure_callback_maximally(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest( - 'Host callback not supported for runtime type: stream_executor.' - ) mesh = Mesh(np.array(jax.devices()), axis_names=('x',)) @@ -934,10 +883,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase): ) def test_can_shard_pure_callback_manually(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest( - 'Host callback not supported for runtime type: stream_executor.' - ) mesh = Mesh(np.array(jax.devices()), axis_names=('x',))