Update shape polymorphism tests to skip lu_pivots_to_permutations tests when jaxlib version is too old.

PiperOrigin-RevId: 662088901
This commit is contained in:
Dan Foreman-Mackey 2024-08-12 08:12:36 -07:00 committed by jax authors
parent 112cae1dad
commit 4eb5ef28ef

View File

@ -49,6 +49,7 @@ from jax._src.export import shape_poly_decision
from jax._src.lax import lax as lax_internal
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lib import xla_client
from jax._src.lib import version as jaxlib_version
import numpy as np
config.parse_flags_with_absl()
@ -3395,9 +3396,22 @@ class ShapePolyHarnessesTest(jtu.JaxTestCase):
"vmap_qr:gpu", "qr:gpu",
"vmap_svd:gpu",
}
if f"{harness.group_name}:{jtu.device_under_test()}" in custom_call_harnesses:
name_device_key = f"{harness.group_name}:{jtu.device_under_test()}"
if name_device_key in custom_call_harnesses:
raise unittest.SkipTest("native serialization with shape polymorphism not implemented for custom calls; b/261671778")
# This list keeps track of the minimum jaxlib version that supports shape
# polymorphism for some new primitives as we add them. This check is
# required so that we can still run the test suite with older versions of
# jaxlib.
version_gated = {
# TODO(danfm): remove these checks when jaxlib 0.4.32 is released.
"lu_pivots_to_permutation:gpu": (0, 4, 32),
"lu_pivots_to_permutation_error:gpu": (0, 4, 32),
}
if version_gated.get(name_device_key, jaxlib_version) > jaxlib_version:
raise unittest.SkipTest(f"shape polymorphism not supported by jaxlib version {jaxlib_version}")
if harness.group_name == "schur" and not jtu.test_device_matches(["cpu"]):
raise unittest.SkipTest("schur decomposition is only implemented on CPU.")