Add version guard to compilation cache test.

PiperOrigin-RevId: 525572568
This commit is contained in:
Peter Hawkins 2023-04-19 15:49:55 -07:00 committed by jax authors
parent fb5664d580
commit 34fd4a1562

View File

@ -197,6 +197,8 @@ class CompilationCacheTest(jtu.JaxTestCase):
cc.get_cache_key(computation2, compile_options, backend),
)
@unittest.skipIf(jax._src.lib.version < (0, 4, 9),
"Test requires jaxlib 0.4.9")
@parameterized.parameters([False, True])
def test_identical_computations_different_metadata(self, include_metadata):
f = lambda x, y: lax.mul(lax.add(x, y), 2)