Skip testAutodiffCache test if xla_extension_version < 123

PiperOrigin-RevId: 507292333
This commit is contained in:
Yash Katariya 2023-02-05 09:38:55 -08:00 committed by jax authors
parent a30ba83db2
commit be67db33d8

View File

@ -565,6 +565,9 @@ class PJitTest(jtu.BufferDonationTestCase):
def testAutodiffCache(self):
if not jax.config.jax_array:
self.skipTest('Does not work without jax.Array')
if xla_extension_version < 123:
self.skipTest('This test requires xla_extension_version >= 123.')
f = pjit(lambda x: jnp.sin(x).sum(),
in_axis_resources=P('x'), out_axis_resources=None)
x = jnp.arange(16, dtype=jnp.float32)