From c9db50cfd0fcfba31c85af1e929504ee4329b5b8 Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Thu, 2 Nov 2023 13:26:13 -0700 Subject: [PATCH] Enable python_callback_test for stream executor. python_callback_test is supported for GPU stream executor. TPU stream executor was deprecated. PiperOrigin-RevId: 578960299 --- tests/python_callback_test.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 98477e89d..ec3e376a3 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -29,7 +29,6 @@ from jax._src import dispatch from jax._src import maps from jax._src import test_util as jtu from jax._src import util -from jax._src import xla_bridge from jax._src.lib import xla_client from jax._src.lib import xla_extension_version from jax.experimental import io_callback @@ -79,8 +78,6 @@ class PythonCallbackTest(jtu.JaxTestCase): super().setUp() if not jtu.test_device_matches(["cpu", "gpu", "tpu"]): self.skipTest(f"Host callback not supported on {jtu.device_under_test()}") - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') def tearDown(self): super().tearDown() @@ -504,8 +501,6 @@ class PureCallbackTest(jtu.JaxTestCase): super().setUp() if not jtu.test_device_matches(["cpu", "gpu", "tpu"]): self.skipTest(f"Host callback not supported on {jtu.device_under_test()}") - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') def tearDown(self): super().tearDown() @@ -911,8 +906,6 @@ class IOCallbackTest(jtu.JaxTestCase): def setUp(self): super().setUp() - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') if not jtu.test_device_matches(["cpu", "gpu", "tpu"]): self.skipTest(f"Host callback not supported on {jtu.device_under_test()}")