Enable python_callback_test for stream executor.

python_callback_test is supported for GPU stream executor. TPU stream executor was deprecated.

PiperOrigin-RevId: 578960299
This commit is contained in:
Jieying Luo 2023-11-02 13:26:13 -07:00 committed by jax authors
parent c8b7c1b80b
commit c9db50cfd0

View File

@ -29,7 +29,6 @@ from jax._src import dispatch
from jax._src import maps
from jax._src import test_util as jtu
from jax._src import util
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax.experimental import io_callback
@ -79,8 +78,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
super().setUp()
if not jtu.test_device_matches(["cpu", "gpu", "tpu"]):
self.skipTest(f"Host callback not supported on {jtu.device_under_test()}")
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def tearDown(self):
super().tearDown()
@ -504,8 +501,6 @@ class PureCallbackTest(jtu.JaxTestCase):
super().setUp()
if not jtu.test_device_matches(["cpu", "gpu", "tpu"]):
self.skipTest(f"Host callback not supported on {jtu.device_under_test()}")
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def tearDown(self):
super().tearDown()
@ -911,8 +906,6 @@ class IOCallbackTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
if not jtu.test_device_matches(["cpu", "gpu", "tpu"]):
self.skipTest(f"Host callback not supported on {jtu.device_under_test()}")