diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 98477e89d..ec3e376a3 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -29,7 +29,6 @@ from jax._src import dispatch from jax._src import maps from jax._src import test_util as jtu from jax._src import util -from jax._src import xla_bridge from jax._src.lib import xla_client from jax._src.lib import xla_extension_version from jax.experimental import io_callback @@ -79,8 +78,6 @@ class PythonCallbackTest(jtu.JaxTestCase): super().setUp() if not jtu.test_device_matches(["cpu", "gpu", "tpu"]): self.skipTest(f"Host callback not supported on {jtu.device_under_test()}") - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') def tearDown(self): super().tearDown() @@ -504,8 +501,6 @@ class PureCallbackTest(jtu.JaxTestCase): super().setUp() if not jtu.test_device_matches(["cpu", "gpu", "tpu"]): self.skipTest(f"Host callback not supported on {jtu.device_under_test()}") - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') def tearDown(self): super().tearDown() @@ -911,8 +906,6 @@ class IOCallbackTest(jtu.JaxTestCase): def setUp(self): super().setUp() - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') if not jtu.test_device_matches(["cpu", "gpu", "tpu"]): self.skipTest(f"Host callback not supported on {jtu.device_under_test()}")