Merge pull request #15949 from gnecula:fix_poly

PiperOrigin-RevId: 530832168
This commit is contained in:
jax authors 2023-05-10 00:51:37 -07:00
commit 48f551378a
2 changed files with 24 additions and 2 deletions

View File

@ -833,9 +833,11 @@ def _parse_spec(shape_spec: Union[str, PolyShape, None],
Args:
shape_spec: a shape polymorphic specification. None stands for "...".
arg_shape: an actual shape, possibly containing unknown dimensions (None).
We use `arg_shape` only to fill-in the placeholders `_` and `...` in
We use `arg_shape` to fill-in the placeholders `_` and `...` in
the `shape_spec`. The dimensions of `arg_shape` that are used for filling
must be known (not `None`).
must be known (not `None`). If a dimension in `arg_shape` is known and
the corresponding dimension in `shape_spec` is a constant then they
must be equal.
See the README.md for usage.
"""
@ -872,6 +874,12 @@ class _Parser:
raise self.parse_err(tok,
("unexpected placeholder for unknown dimension "
f"in argument shape {self.arg_shape}"))
arg_shape_dim = self.arg_shape[len(self.dimensions)]
if core.is_constant_dim(expr) and arg_shape_dim is not None:
if expr != arg_shape_dim:
raise ValueError(
f"polymorphic shape {self.shape_spec_repr} in axis {len(self.dimensions)} "
f"must match the known dimension size in arg shape {self.arg_shape}")
self.dimensions.append(expr)
def parse_err(self, tok: Optional[tokenize.TokenInfo], detail: str) -> Exception:

View File

@ -119,6 +119,20 @@ class DimExprTest(tf_test_util.JaxToTfTestCase):
"syntax error in polymorphic shape"):
shape_poly._parse_spec(shape_spec, (None,))
@parameterized.named_parameters(
dict(testcase_name=f"_{shape_spec=}",
shape_spec=shape_spec, arg_shape=arg_shape)
for shape_spec, arg_shape in [
("3", (4,)),
("b, 3", (None, 4)),
])
def test_parse_mismatch_error(self,
shape_spec="3", arg_shape=(4,)):
with self.assertRaisesRegex(ValueError,
"polymorphic shape .* in axis .* must match the known dimension size"):
shape_poly._parse_spec(shape_spec, arg_shape)
def test_dim_vars(self):
a, b, a1 = shape_poly._parse_spec("a, b, a", (2, 3, 2))
self.assertEqual(True, a == a)