mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[shape_poly] Relax the limit on the number of error inputs for shape assertions
Lift the limit from 4 to 32 to follow the change in tf.XlaCallModule of this limit. Currently, the error message formatter needs at most 6 error message inputs.
This commit is contained in:
parent
7a5b9ff522
commit
26eb6c3f27
@ -55,6 +55,7 @@ from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import effects
|
||||
from jax._src.lax import lax
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.numpy import lax_numpy
|
||||
from jax._src import tree_util
|
||||
@ -1269,9 +1270,12 @@ class ShapeConstraint:
|
||||
See shape_assertion.
|
||||
"""
|
||||
# There is currenly a limitation in the shape assertion checker that
|
||||
# it supports at most 4 error_message_inputs. We try to stay within the
|
||||
# it supports at most 32 error_message_inputs. We try to stay within the
|
||||
# limit, reusing a format specifier if possible.
|
||||
# TODO(necula): remove this limit
|
||||
if jaxlib_version <= (0, 4, 14):
|
||||
max_error_message_inputs = 4
|
||||
else:
|
||||
max_error_message_inputs = 32
|
||||
format_specifiers: dict[DimSize, str] = {}
|
||||
error_message_inputs: list[Any] = []
|
||||
error_message_strings: list[str] = []
|
||||
@ -1283,7 +1287,7 @@ class ShapeConstraint:
|
||||
if cached_spec is not None:
|
||||
error_message_strings.append(cached_spec)
|
||||
continue
|
||||
if len(error_message_inputs) >= 4:
|
||||
if len(error_message_inputs) >= max_error_message_inputs:
|
||||
error_message_strings.append("N/A")
|
||||
continue
|
||||
spec = "{" + str(len(error_message_inputs)) + "}"
|
||||
|
@ -25,12 +25,12 @@ from jax import numpy as jnp
|
||||
from jax import tree_util
|
||||
from jax.config import config
|
||||
from jax.experimental.jax2tf import jax_export
|
||||
from jax.lib import xla_client as xc
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
|
||||
@ -489,7 +489,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
# call_exported error reporting.
|
||||
@jtu.parameterized_filterable(
|
||||
#one_containing="7, 2, 36",
|
||||
testcase_name=lambda kw: kw["shape"],
|
||||
testcase_name=lambda kw: kw["shape"], # assume "shape" is unique
|
||||
kwargs=[
|
||||
dict(shape=(8, 2, 9), # a = 2, b = 3, c = 4
|
||||
poly_spec="(a + 2*b, a, a + b + c)"),
|
||||
@ -519,7 +519,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
"Found inconsistency between dimension size args[0].shape[0] (= 8) and the specification 'a + 2*b' (= 10). "
|
||||
"Using the following polymorphic shapes specifications: args[0].shape = (a + 2*b, a, a + b). "
|
||||
"Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), "
|
||||
"'b' = 4 from specification 'a + b' for dimension args[0].shape[2] (= N/A), . "
|
||||
"'b' = 4 from specification 'a + b' for dimension args[0].shape[2] (= 6), . "
|
||||
"Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details."
|
||||
)),
|
||||
dict(shape=(7, 2, 36), # a = 2, b = 3, c = 6 - cannot solve c
|
||||
@ -537,6 +537,8 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
def f_jax(x): # x: f32[a + 2*b, a, a + b + c]
|
||||
return 0.
|
||||
|
||||
if shape == (8, 2, 6) and jaxlib_version <= (0, 4, 14):
|
||||
raise unittest.SkipTest("Test requires jaxlib >= 0.4.14")
|
||||
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
||||
with contextlib.ExitStack() as stack:
|
||||
if expect_error is not None:
|
||||
|
@ -1074,7 +1074,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
"Found inconsistency between dimension size args[0].shape[0] (= 8) and the specification 'a + 2*b' (= 10). "
|
||||
"Using the following polymorphic shapes specifications: args[0].shape = (a + 2*b, a, a + b). "
|
||||
"Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), "
|
||||
"'b' = 4 from specification 'a + b' for dimension args[0].shape[2] (= N/A), . "
|
||||
"'b' = 4 from specification 'a + b' for dimension args[0].shape[2] (= 6), . "
|
||||
"Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details."
|
||||
)),
|
||||
dict(shape=(7, 2, 36), # a = 2, b = 3, c = 6 - cannot solve c
|
||||
@ -1092,6 +1092,8 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
def f_jax(x): # x: f32[a + 2*b, a, a + b + c]
|
||||
return 0.
|
||||
|
||||
if shape == (8, 2, 6) and jaxlib_version <= (0, 4, 14):
|
||||
raise unittest.SkipTest("Test requires jaxlib >= 0.4.14")
|
||||
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
||||
with contextlib.ExitStack() as stack:
|
||||
if expect_error is not None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user