mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[shape_poly] Minor cleanup
This commit is contained in:
parent
cf4e1d414b
commit
b90a7b7539
@ -1044,7 +1044,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
# Tests details of the shape constraints errors.
|
||||
# This test exists also in jax_export_test.py.
|
||||
@jtu.parameterized_filterable(
|
||||
#one_containing="",
|
||||
testcase_name=lambda kw: kw["shape"],
|
||||
kwargs=[
|
||||
dict(shape=(8, 2, 9), # a = 2, b = 3, c = 4
|
||||
@ -2158,11 +2157,11 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
polymorphic_shapes=["b, ...", None]).both_enable_and_disable_xla(),
|
||||
[
|
||||
[
|
||||
PolyHarness(cum_name, "reduce_axis=poly",
|
||||
PolyHarness(cum_name, "reduce_axis_poly",
|
||||
lambda x: cum_func(x, axis=0),
|
||||
arg_descriptors=[RandArg((3, 5), _f32)],
|
||||
polymorphic_shapes=["b, ..."]),
|
||||
PolyHarness(cum_name, "reduce_axis=static",
|
||||
PolyHarness(cum_name, "reduce_axis_static",
|
||||
lambda x: cum_func(x, axis=1),
|
||||
arg_descriptors=[RandArg((3, 5), _f32)],
|
||||
polymorphic_shapes=["b, ..."])
|
||||
@ -3103,17 +3102,6 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
||||
raise unittest.SkipTest(
|
||||
"native lowering with shape polymorphism requires additional StableHLO feature support")
|
||||
|
||||
# Some tests need the latest jaxlib
|
||||
need_new_jaxlib = []
|
||||
if jaxlib_version < (0, 4, 13):
|
||||
need_new_jaxlib.append("fft")
|
||||
elif jaxlib_version < (0, 4, 14):
|
||||
need_new_jaxlib.extend(("lu", "vmap_lu", "custom_linear_solve",
|
||||
"vmap_custom_linear_solve",
|
||||
"vmap_approx_top_k", "schur"))
|
||||
if harness.group_name in need_new_jaxlib:
|
||||
raise unittest.SkipTest("native lowering with shape polymorphism needs newer jaxlib")
|
||||
|
||||
if (jtu.device_under_test() in ["cpu", "gpu"] and
|
||||
harness.fullname in [
|
||||
"cumsum_reduce_axis_poly", "cumprod_reduce_axis_poly",
|
||||
@ -3121,7 +3109,10 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
||||
"cumlogsumexp_reduce_axis_poly",
|
||||
"jnp_insert_insert_constant", "jnp_insert_insert_poly",
|
||||
"jnp_nonzero_size_constant", "jnp_nonzero_size_poly"]):
|
||||
# Need associative scan reductions on CPU and GPU
|
||||
# Need associative scan reductions on CPU and GPU. On TPU we use the
|
||||
# reduce_window HLO, but on CPU and GPU (with axis size >= 32) we use
|
||||
# a recursive associative scan that we cannot express with shape
|
||||
# polymorphism.
|
||||
raise unittest.SkipTest(
|
||||
"native serialization with shape polymorphism not implemented for window_reductions on CPU and GPU")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user