mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #15949 from gnecula:fix_poly
PiperOrigin-RevId: 530832168
This commit is contained in:
commit
48f551378a
@ -833,9 +833,11 @@ def _parse_spec(shape_spec: Union[str, PolyShape, None],
|
||||
Args:
|
||||
shape_spec: a shape polymorphic specification. None stands for "...".
|
||||
arg_shape: an actual shape, possibly containing unknown dimensions (None).
|
||||
We use `arg_shape` only to fill-in the placeholders `_` and `...` in
|
||||
We use `arg_shape` to fill-in the placeholders `_` and `...` in
|
||||
the `shape_spec`. The dimensions of `arg_shape` that are used for filling
|
||||
must be known (not `None`).
|
||||
must be known (not `None`). If a dimension in `arg_shape` is known and
|
||||
the corresponding dimension in `shape_spec` is a constant then they
|
||||
must be equal.
|
||||
|
||||
See the README.md for usage.
|
||||
"""
|
||||
@ -872,6 +874,12 @@ class _Parser:
|
||||
raise self.parse_err(tok,
|
||||
("unexpected placeholder for unknown dimension "
|
||||
f"in argument shape {self.arg_shape}"))
|
||||
arg_shape_dim = self.arg_shape[len(self.dimensions)]
|
||||
if core.is_constant_dim(expr) and arg_shape_dim is not None:
|
||||
if expr != arg_shape_dim:
|
||||
raise ValueError(
|
||||
f"polymorphic shape {self.shape_spec_repr} in axis {len(self.dimensions)} "
|
||||
f"must match the known dimension size in arg shape {self.arg_shape}")
|
||||
self.dimensions.append(expr)
|
||||
|
||||
def parse_err(self, tok: Optional[tokenize.TokenInfo], detail: str) -> Exception:
|
||||
|
@ -119,6 +119,20 @@ class DimExprTest(tf_test_util.JaxToTfTestCase):
|
||||
"syntax error in polymorphic shape"):
|
||||
shape_poly._parse_spec(shape_spec, (None,))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(testcase_name=f"_{shape_spec=}",
|
||||
shape_spec=shape_spec, arg_shape=arg_shape)
|
||||
for shape_spec, arg_shape in [
|
||||
("3", (4,)),
|
||||
("b, 3", (None, 4)),
|
||||
])
|
||||
def test_parse_mismatch_error(self,
|
||||
shape_spec="3", arg_shape=(4,)):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"polymorphic shape .* in axis .* must match the known dimension size"):
|
||||
shape_poly._parse_spec(shape_spec, arg_shape)
|
||||
|
||||
|
||||
def test_dim_vars(self):
|
||||
a, b, a1 = shape_poly._parse_spec("a, b, a", (2, 3, 2))
|
||||
self.assertEqual(True, a == a)
|
||||
|
Loading…
x
Reference in New Issue
Block a user