Improve handling of dynamic shapes in jax native serialization

PiperOrigin-RevId: 518634912
This commit is contained in:
Kevin Gleason 2023-03-22 12:00:29 -07:00 committed by jax authors
parent f106d45371
commit 022b47fd91

View File

@ -36,6 +36,7 @@ from jax._src import test_util as jtu
from jax._src import util
from jax._src.lax import lax as lax_internal
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lib import xla_client
import numpy as np
from jax.experimental.jax2tf.tests import tf_test_util
@ -2634,18 +2635,25 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
# Set of harness.group_name that are unsupported in serialization
require_stablehlo_feature_support = {
# Tan and TopK require additional support for dynamic shape lowering
# Tan (b/274462307) and TopK (openxla/stablehlo#1255) require support.
"vmap_tan", "vmap_top_k",
# Filter CHLO decompositions that produce shape dialect ops
"vmap_acosh", "vmap_asin", "vmap_asinh", "vmap_atan", "vmap_atanh",
"vmap_bessel_i1e", "vmap_cosh", "vmap_digamma", "vmap_erf",
"vmap_erfc", "vmap_lgamma", "vmap_nextafter",
"vmap_nextafter_broadcasting", "vmap_sinh",
# Crash due to openxla/stablehlo#1328
"vmap_random_randint", "vmap_random_uniform"
}
if harness.group_name in require_stablehlo_feature_support:
raise unittest.SkipTest(
"native lowering with shape polymorphism requires additional StableHLO feature support")
# API version 47 supports CHLO ops that decompose into shape dialect ops
if xla_client.mlir_api_version < 47:
require_stablehlo_feature_support_shape_dialect = {
"vmap_acosh", "vmap_asin", "vmap_asinh", "vmap_atan", "vmap_atanh",
"vmap_bessel_i1e", "vmap_cosh", "vmap_digamma", "vmap_erf",
"vmap_erfc", "vmap_lgamma", "vmap_nextafter",
"vmap_nextafter_broadcasting", "vmap_sinh"
}
if harness.group_name in require_stablehlo_feature_support_shape_dialect:
raise unittest.SkipTest(
"native lowering with shape polymorphism requires additional StableHLO feature support")
if (jtu.device_under_test() == "tpu" and
harness.fullname in [
"jnp.cumsum_reduce_axis=poly",