mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Fix jax.clear_backends()
on Cloud TPU
PiperOrigin-RevId: 542279072
This commit is contained in:
parent
f46e5141b1
commit
529822cbcf
@ -21,6 +21,7 @@ import jax
|
||||
from jax import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.lib import xla_extension_version
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
@ -28,7 +29,7 @@ config.parse_flags_with_absl()
|
||||
class ClearBackendsTest(jtu.JaxTestCase):
|
||||
|
||||
def test_clear_backends(self):
|
||||
if xb.using_pjrt_c_api():
|
||||
if xla_extension_version < 164 and xb.using_pjrt_c_api():
|
||||
raise unittest.SkipTest('test crashes runtime with PJRT C API')
|
||||
|
||||
g = jax.jit(lambda x, y: x * y)
|
||||
|
Loading…
x
Reference in New Issue
Block a user