[shape_poly] Minor cleanup

This commit is contained in:
George Necula 2023-08-12 08:25:45 +03:00
parent cf4e1d414b
commit b90a7b7539

View File

@ -1044,7 +1044,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
# Tests details of the shape constraints errors.
# This test exists also in jax_export_test.py.
@jtu.parameterized_filterable(
#one_containing="",
testcase_name=lambda kw: kw["shape"],
kwargs=[
dict(shape=(8, 2, 9), # a = 2, b = 3, c = 4
@ -2158,11 +2157,11 @@ _POLY_SHAPE_TEST_HARNESSES = [
polymorphic_shapes=["b, ...", None]).both_enable_and_disable_xla(),
[
[
PolyHarness(cum_name, "reduce_axis=poly",
PolyHarness(cum_name, "reduce_axis_poly",
lambda x: cum_func(x, axis=0),
arg_descriptors=[RandArg((3, 5), _f32)],
polymorphic_shapes=["b, ..."]),
PolyHarness(cum_name, "reduce_axis=static",
PolyHarness(cum_name, "reduce_axis_static",
lambda x: cum_func(x, axis=1),
arg_descriptors=[RandArg((3, 5), _f32)],
polymorphic_shapes=["b, ..."])
@ -3103,17 +3102,6 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
raise unittest.SkipTest(
"native lowering with shape polymorphism requires additional StableHLO feature support")
# Some tests need the latest jaxlib
need_new_jaxlib = []
if jaxlib_version < (0, 4, 13):
need_new_jaxlib.append("fft")
elif jaxlib_version < (0, 4, 14):
need_new_jaxlib.extend(("lu", "vmap_lu", "custom_linear_solve",
"vmap_custom_linear_solve",
"vmap_approx_top_k", "schur"))
if harness.group_name in need_new_jaxlib:
raise unittest.SkipTest("native lowering with shape polymorphism needs newer jaxlib")
if (jtu.device_under_test() in ["cpu", "gpu"] and
harness.fullname in [
"cumsum_reduce_axis_poly", "cumprod_reduce_axis_poly",
@ -3121,7 +3109,10 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
"cumlogsumexp_reduce_axis_poly",
"jnp_insert_insert_constant", "jnp_insert_insert_poly",
"jnp_nonzero_size_constant", "jnp_nonzero_size_poly"]):
# Need associative scan reductions on CPU and GPU
# Need associative scan reductions on CPU and GPU. On TPU we use the
# reduce_window HLO, but on CPU and GPU (with axis size >= 32) we use
# a recursive associative scan that we cannot express with shape
# polymorphism.
raise unittest.SkipTest(
"native serialization with shape polymorphism not implemented for window_reductions on CPU and GPU")