mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[callback] Disable stream_executor tests.
PiperOrigin-RevId: 559252832
This commit is contained in:
parent
bf29c5e5f1
commit
26f091e446
@ -1050,6 +1050,9 @@ jax_test(
|
||||
jax_test(
|
||||
name = "python_callback_test",
|
||||
srcs = ["python_callback_test.py"],
|
||||
disable_configs = [
|
||||
"tpu_se", # Host callback does not work on stream executor.
|
||||
],
|
||||
deps = [
|
||||
"//jax:experimental",
|
||||
],
|
||||
|
@ -129,14 +129,16 @@ mlir.register_lowering(callback_p, callback_lowering, platform="tpu")
|
||||
|
||||
class PythonCallbackTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
dispatch.runtime_tokens.clear()
|
||||
|
||||
def test_callback_with_scalar_values(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
return callback(lambda x: x + np.float32(1.),
|
||||
@ -205,8 +207,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
|
||||
jax.effects_barrier()
|
||||
|
||||
def test_callback_with_single_return_value(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
@jax.jit
|
||||
def f():
|
||||
@ -218,8 +218,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
|
||||
np.testing.assert_allclose(out, np.ones(4, np.float32))
|
||||
|
||||
def test_callback_with_multiple_return_values(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
@jax.jit
|
||||
def f():
|
||||
@ -233,8 +231,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
|
||||
np.testing.assert_allclose(y, np.ones(5, np.int32))
|
||||
|
||||
def test_callback_with_multiple_arguments_and_return_values(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def _callback(x, y, z):
|
||||
return (x, y + z)
|
||||
@ -310,8 +306,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(res, (result0, result1, result2, result3))
|
||||
|
||||
def test_callback_with_pytree_arguments_and_return_values(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def _callback(x):
|
||||
return dict(y=[x])
|
||||
@ -326,8 +320,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out, dict(y=[2.]))
|
||||
|
||||
def test_callback_inside_of_while_loop_of_scalars(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def _callback(x):
|
||||
return (x + 1.).astype(x.dtype)
|
||||
@ -345,8 +337,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out, 10.)
|
||||
|
||||
def test_callback_inside_of_while_loop(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def _callback(x):
|
||||
return (x + 1.).astype(x.dtype)
|
||||
@ -367,8 +357,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
|
||||
np.testing.assert_allclose(out, jnp.arange(10., 15.))
|
||||
|
||||
def test_callback_inside_of_cond_of_scalars(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def _callback1(x):
|
||||
return (x + 1.).astype(x.dtype)
|
||||
@ -395,8 +383,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out, 0.)
|
||||
|
||||
def test_callback_inside_of_cond(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def _callback1(x):
|
||||
return x + 1.
|
||||
@ -423,8 +409,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
|
||||
np.testing.assert_allclose(out, jnp.zeros(2))
|
||||
|
||||
def test_callback_inside_of_scan_of_scalars(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def _callback(x):
|
||||
return (x + 1.).astype(x.dtype)
|
||||
@ -443,8 +427,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out, 10.)
|
||||
|
||||
def test_callback_inside_of_scan(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def _callback(x):
|
||||
return x + 1.
|
||||
@ -463,8 +445,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
|
||||
np.testing.assert_allclose(out, jnp.arange(2.) + 10.)
|
||||
|
||||
def test_callback_inside_of_pmap_of_scalars(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def _callback(x):
|
||||
return (x + 1.).astype(x.dtype)
|
||||
@ -479,8 +459,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
|
||||
out, np.arange(jax.local_device_count(), dtype=np.float32) + 1.)
|
||||
|
||||
def test_callback_inside_of_pmap(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def _callback(x):
|
||||
return x + 1.
|
||||
@ -499,13 +477,16 @@ class PythonCallbackTest(jtu.JaxTestCase):
|
||||
|
||||
class PurePythonCallbackTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
dispatch.runtime_tokens.clear()
|
||||
|
||||
def test_pure_callback_passes_ndarrays_without_jit(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def cb(x):
|
||||
self.assertIs(type(x), np.ndarray)
|
||||
@ -516,8 +497,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
|
||||
f(jnp.array(2.))
|
||||
|
||||
def test_simple_pure_callback(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
@ -614,8 +593,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
|
||||
jax.effects_barrier()
|
||||
|
||||
def test_can_vmap_pure_callback(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
@jax.jit
|
||||
@jax.vmap
|
||||
@ -649,8 +626,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
|
||||
rtol=1E-7, check_dtypes=False)
|
||||
|
||||
def test_vmap_vectorized_callback(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def cb(x):
|
||||
self.assertTupleEqual(x.shape, ())
|
||||
@ -698,8 +673,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
|
||||
jax.effects_barrier()
|
||||
|
||||
def test_can_pmap_pure_callback(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
@jax.pmap
|
||||
def f(x):
|
||||
@ -708,10 +681,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
|
||||
np.testing.assert_allclose(out, np.sin(np.arange(jax.local_device_count())))
|
||||
|
||||
def test_can_pjit_pure_callback_under_hard_xmap(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest(
|
||||
'Host callback not supported for runtime type: stream_executor.'
|
||||
)
|
||||
|
||||
if not hasattr(xla_client.OpSharding.Type, 'MANUAL'):
|
||||
raise unittest.SkipTest('Manual partitioning needed for pure_callback')
|
||||
@ -760,8 +729,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
|
||||
f(2.)
|
||||
|
||||
def test_can_take_grad_of_pure_callback_with_custom_jvp(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
@jax.custom_jvp
|
||||
def sin(x):
|
||||
@ -780,8 +747,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
|
||||
np.testing.assert_allclose(out, jnp.cos(2.))
|
||||
|
||||
def test_callback_inside_of_cond(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def _callback1(x):
|
||||
return x + 1.
|
||||
@ -806,8 +771,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
|
||||
np.testing.assert_allclose(out, jnp.zeros(2))
|
||||
|
||||
def test_callback_inside_of_scan(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def _callback(x):
|
||||
return x + 1.
|
||||
@ -825,8 +788,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
|
||||
np.testing.assert_allclose(out, jnp.arange(2.) + 10.)
|
||||
|
||||
def test_callback_inside_of_while_loop(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def _cond_callback(x):
|
||||
return np.any(x < 10)
|
||||
@ -850,8 +811,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
|
||||
np.testing.assert_allclose(out, jnp.arange(10., 15.))
|
||||
|
||||
def test_callback_inside_of_pmap(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def _callback(x):
|
||||
return x + 1.
|
||||
@ -868,8 +827,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
|
||||
np.arange(2 * jax.local_device_count()).reshape([-1, 2]) + 1.)
|
||||
|
||||
def test_callback_inside_xmap(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def _callback(x):
|
||||
return (x + 1.).astype(x.dtype)
|
||||
@ -884,8 +841,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
|
||||
np.testing.assert_allclose(out, jnp.arange(1., 41.))
|
||||
|
||||
def test_vectorized_callback_inside_xmap(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def _callback(x):
|
||||
return (x + 1.).astype(x.dtype)
|
||||
@ -900,8 +855,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
|
||||
np.testing.assert_allclose(out, jnp.arange(1., 41.))
|
||||
|
||||
def test_array_layout_is_preserved(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def g(x):
|
||||
return jax.pure_callback(lambda x: x, x, x)
|
||||
@ -910,10 +863,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
|
||||
np.testing.assert_allclose(g(x), x)
|
||||
|
||||
def test_can_shard_pure_callback_maximally(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest(
|
||||
'Host callback not supported for runtime type: stream_executor.'
|
||||
)
|
||||
|
||||
mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
|
||||
|
||||
@ -934,10 +883,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
|
||||
)
|
||||
|
||||
def test_can_shard_pure_callback_manually(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest(
|
||||
'Host callback not supported for runtime type: stream_executor.'
|
||||
)
|
||||
|
||||
mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user