Merge pull request #9576 from nicholasjng:broadcast-validation

PiperOrigin-RevId: 432531230
This commit is contained in:
jax authors 2022-03-04 14:21:17 -08:00
commit 2a3f936ffa
2 changed files with 25 additions and 1 deletions

View File

@ -83,9 +83,23 @@ Shape = core.Shape
T = TypeVar("T")
def _validate_shapes(shapes: Sequence[Shape]):
def _check_static_shape(shape: Shape):
checked = canonicalize_shape(shape)
if not all(idx >= 0 for idx in checked):
msg = f"Only non-negative indices are allowed when broadcasting" \
f" static shapes, but got shape {shape!r}."
raise TypeError(msg)
assert shapes
if config.jax_dynamic_shapes:
# pass dynamic shapes through unchecked
return
else:
_ = tuple(map(_check_static_shape, shapes))
def _try_broadcast_shapes(
shapes: Sequence[Tuple[int, ...]]) -> Optional[Tuple[int, ...]]:
assert shapes
if len(shapes) == 1: return shapes[0]
rank, *others = {len(shape) for shape in shapes}
if others: return None # must have consistent rank
@ -113,6 +127,7 @@ def _broadcast_shapes_cached(*shapes: Tuple[int, ...]) -> Tuple[int, ...]:
return _broadcast_shapes_uncached(*shapes)
def _broadcast_shapes_uncached(*shapes):
_validate_shapes(shapes)
fst, *rst = shapes
if not rst: return fst

View File

@ -723,6 +723,15 @@ class LaxVmapTest(jtu.JaxTestCase):
out_shape = lax.broadcast_shapes(shape1, shape2)
self.assertTrue(all(type(s) is int for s in out_shape))
def testBroadcastShapesFaultyInputs(self):
err_shape1, err_shape2 = (-1,), "hello"
# negative inputs should fail while informing about illegal negative indices...
with self.assertRaisesRegex(TypeError, "Only non-negative indices are allowed.*"):
lax.broadcast_shapes(err_shape1)
# ... while non-integers should error earlier, in the canonicalize_shape machinery.
with self.assertRaisesRegex(TypeError, "Shapes must be 1D sequences.*"):
lax.broadcast_shapes(err_shape2) # pytype: disable=wrong-arg-types
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_k={}_bdims={}".format(
jtu.format_shape_dtype_string(shape, dtype), k, bdims),