[shape_poly] Relax the limit on the number of error inputs for shape assertions

Lift the limit from 4 to 32 to follow the change in tf.XlaCallModule of
this limit. Currently, the error message formatter needs at most 6
error message inputs.
This commit is contained in:
George Necula 2023-08-17 11:19:04 +02:00
parent 7a5b9ff522
commit 26eb6c3f27
3 changed files with 15 additions and 7 deletions

View File

@ -55,6 +55,7 @@ from jax._src import core
from jax._src import dtypes
from jax._src import effects
from jax._src.lax import lax
from jax._src.lib import version as jaxlib_version
from jax._src.interpreters import mlir
from jax._src.numpy import lax_numpy
from jax._src import tree_util
@ -1269,9 +1270,12 @@ class ShapeConstraint:
See shape_assertion.
"""
# There is currenly a limitation in the shape assertion checker that
# it supports at most 4 error_message_inputs. We try to stay within the
# it supports at most 32 error_message_inputs. We try to stay within the
# limit, reusing a format specifier if possible.
# TODO(necula): remove this limit
if jaxlib_version <= (0, 4, 14):
max_error_message_inputs = 4
else:
max_error_message_inputs = 32
format_specifiers: dict[DimSize, str] = {}
error_message_inputs: list[Any] = []
error_message_strings: list[str] = []
@ -1283,7 +1287,7 @@ class ShapeConstraint:
if cached_spec is not None:
error_message_strings.append(cached_spec)
continue
if len(error_message_inputs) >= 4:
if len(error_message_inputs) >= max_error_message_inputs:
error_message_strings.append("N/A")
continue
spec = "{" + str(len(error_message_inputs)) + "}"

View File

@ -25,12 +25,12 @@ from jax import numpy as jnp
from jax import tree_util
from jax.config import config
from jax.experimental.jax2tf import jax_export
from jax.lib import xla_client as xc
from jax._src import core
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.interpreters import mlir
from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir.dialects import hlo
@ -489,7 +489,7 @@ class JaxExportTest(jtu.JaxTestCase):
# call_exported error reporting.
@jtu.parameterized_filterable(
#one_containing="7, 2, 36",
testcase_name=lambda kw: kw["shape"],
testcase_name=lambda kw: kw["shape"], # assume "shape" is unique
kwargs=[
dict(shape=(8, 2, 9), # a = 2, b = 3, c = 4
poly_spec="(a + 2*b, a, a + b + c)"),
@ -519,7 +519,7 @@ class JaxExportTest(jtu.JaxTestCase):
"Found inconsistency between dimension size args[0].shape[0] (= 8) and the specification 'a + 2*b' (= 10). "
"Using the following polymorphic shapes specifications: args[0].shape = (a + 2*b, a, a + b). "
"Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), "
"'b' = 4 from specification 'a + b' for dimension args[0].shape[2] (= N/A), . "
"'b' = 4 from specification 'a + b' for dimension args[0].shape[2] (= 6), . "
"Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details."
)),
dict(shape=(7, 2, 36), # a = 2, b = 3, c = 6 - cannot solve c
@ -537,6 +537,8 @@ class JaxExportTest(jtu.JaxTestCase):
def f_jax(x): # x: f32[a + 2*b, a, a + b + c]
return 0.
if shape == (8, 2, 6) and jaxlib_version <= (0, 4, 14):
raise unittest.SkipTest("Test requires jaxlib >= 0.4.14")
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
with contextlib.ExitStack() as stack:
if expect_error is not None:

View File

@ -1074,7 +1074,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
"Found inconsistency between dimension size args[0].shape[0] (= 8) and the specification 'a + 2*b' (= 10). "
"Using the following polymorphic shapes specifications: args[0].shape = (a + 2*b, a, a + b). "
"Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), "
"'b' = 4 from specification 'a + b' for dimension args[0].shape[2] (= N/A), . "
"'b' = 4 from specification 'a + b' for dimension args[0].shape[2] (= 6), . "
"Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details."
)),
dict(shape=(7, 2, 36), # a = 2, b = 3, c = 6 - cannot solve c
@ -1092,6 +1092,8 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
def f_jax(x): # x: f32[a + 2*b, a, a + b + c]
return 0.
if shape == (8, 2, 6) and jaxlib_version <= (0, 4, 14):
raise unittest.SkipTest("Test requires jaxlib >= 0.4.14")
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
with contextlib.ExitStack() as stack:
if expect_error is not None: