Fix jax.clear_backends() on Cloud TPU

PiperOrigin-RevId: 542279072
This commit is contained in:
Skye Wanderman-Milne 2023-06-21 09:28:20 -07:00 committed by jax authors
parent f46e5141b1
commit 529822cbcf

View File

@ -21,6 +21,7 @@ import jax
from jax import config
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.lib import xla_extension_version
config.parse_flags_with_absl()
@ -28,7 +29,7 @@ config.parse_flags_with_absl()
class ClearBackendsTest(jtu.JaxTestCase):
def test_clear_backends(self):
if xb.using_pjrt_c_api():
if xla_extension_version < 164 and xb.using_pjrt_c_api():
raise unittest.SkipTest('test crashes runtime with PJRT C API')
g = jax.jit(lambda x, y: x * y)