mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Improve handling of dynamic shapes in jax native serialization
PiperOrigin-RevId: 518634912
This commit is contained in:
parent
f106d45371
commit
022b47fd91
@ -36,6 +36,7 @@ from jax._src import test_util as jtu
|
||||
from jax._src import util
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import control_flow as lax_control_flow
|
||||
from jax._src.lib import xla_client
|
||||
import numpy as np
|
||||
|
||||
from jax.experimental.jax2tf.tests import tf_test_util
|
||||
@ -2634,18 +2635,25 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
# Set of harness.group_name that are unsupported in serialization
|
||||
require_stablehlo_feature_support = {
|
||||
# Tan and TopK require additional support for dynamic shape lowering
|
||||
# Tan (b/274462307) and TopK (openxla/stablehlo#1255) require support.
|
||||
"vmap_tan", "vmap_top_k",
|
||||
# Filter CHLO decompositions that produce shape dialect ops
|
||||
"vmap_acosh", "vmap_asin", "vmap_asinh", "vmap_atan", "vmap_atanh",
|
||||
"vmap_bessel_i1e", "vmap_cosh", "vmap_digamma", "vmap_erf",
|
||||
"vmap_erfc", "vmap_lgamma", "vmap_nextafter",
|
||||
"vmap_nextafter_broadcasting", "vmap_sinh",
|
||||
# Crash due to openxla/stablehlo#1328
|
||||
"vmap_random_randint", "vmap_random_uniform"
|
||||
}
|
||||
if harness.group_name in require_stablehlo_feature_support:
|
||||
raise unittest.SkipTest(
|
||||
"native lowering with shape polymorphism requires additional StableHLO feature support")
|
||||
|
||||
# API version 47 supports CHLO ops that decompose into shape dialect ops
|
||||
if xla_client.mlir_api_version < 47:
|
||||
require_stablehlo_feature_support_shape_dialect = {
|
||||
"vmap_acosh", "vmap_asin", "vmap_asinh", "vmap_atan", "vmap_atanh",
|
||||
"vmap_bessel_i1e", "vmap_cosh", "vmap_digamma", "vmap_erf",
|
||||
"vmap_erfc", "vmap_lgamma", "vmap_nextafter",
|
||||
"vmap_nextafter_broadcasting", "vmap_sinh"
|
||||
}
|
||||
if harness.group_name in require_stablehlo_feature_support_shape_dialect:
|
||||
raise unittest.SkipTest(
|
||||
"native lowering with shape polymorphism requires additional StableHLO feature support")
|
||||
if (jtu.device_under_test() == "tpu" and
|
||||
harness.fullname in [
|
||||
"jnp.cumsum_reduce_axis=poly",
|
||||
|
Loading…
x
Reference in New Issue
Block a user