mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Update shape polymorphism tests to skip lu_pivots_to_permutations tests when jaxlib version is too old.
PiperOrigin-RevId: 662088901
This commit is contained in:
parent
112cae1dad
commit
4eb5ef28ef
@ -49,6 +49,7 @@ from jax._src.export import shape_poly_decision
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import control_flow as lax_control_flow
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
@ -3395,9 +3396,22 @@ class ShapePolyHarnessesTest(jtu.JaxTestCase):
|
||||
"vmap_qr:gpu", "qr:gpu",
|
||||
"vmap_svd:gpu",
|
||||
}
|
||||
if f"{harness.group_name}:{jtu.device_under_test()}" in custom_call_harnesses:
|
||||
name_device_key = f"{harness.group_name}:{jtu.device_under_test()}"
|
||||
if name_device_key in custom_call_harnesses:
|
||||
raise unittest.SkipTest("native serialization with shape polymorphism not implemented for custom calls; b/261671778")
|
||||
|
||||
# This list keeps track of the minimum jaxlib version that supports shape
|
||||
# polymorphism for some new primitives as we add them. This check is
|
||||
# required so that we can still run the test suite with older versions of
|
||||
# jaxlib.
|
||||
version_gated = {
|
||||
# TODO(danfm): remove these checks when jaxlib 0.4.32 is released.
|
||||
"lu_pivots_to_permutation:gpu": (0, 4, 32),
|
||||
"lu_pivots_to_permutation_error:gpu": (0, 4, 32),
|
||||
}
|
||||
if version_gated.get(name_device_key, jaxlib_version) > jaxlib_version:
|
||||
raise unittest.SkipTest(f"shape polymorphism not supported by jaxlib version {jaxlib_version}")
|
||||
|
||||
if harness.group_name == "schur" and not jtu.test_device_matches(["cpu"]):
|
||||
raise unittest.SkipTest("schur decomposition is only implemented on CPU.")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user