Validate lax.broadcast_shape inputs before control flow execution

This commit addresses previously unvalidated inputs to `jax.lax.broadcast_shapes` by adding a small validation check before control flow execution. A legal input to `lax.broadcast_shapes` hereafter is defined as an input that
1) is a sequence (i.e., implements for..in iteration) of integers and
2) said integers are all non-negative.

In addition, two tests were added to `tests.lax_vmap_test` to check that proper errors are raised when attempting to use illegal inputs with `lax.broadcast_shapes`.
This commit is contained in:
Nicholas Junge 2022-02-15 15:03:33 +01:00
parent 7d02949d24
commit 56546d3e73
2 changed files with 25 additions and 1 deletions

View File

@ -83,9 +83,23 @@ Shape = core.Shape
T = TypeVar("T")
def _validate_shapes(shapes: Sequence[Shape]):
def _check_static_shape(shape: Shape):
checked = canonicalize_shape(shape)
if not all(idx >= 0 for idx in checked):
msg = f"Only non-negative indices are allowed when broadcasting" \
f" static shapes, but got shape {shape!r}."
raise TypeError(msg)
assert shapes
if config.jax_dynamic_shapes:
# pass dynamic shapes through unchecked
return
else:
_ = tuple(map(_check_static_shape, shapes))
def _try_broadcast_shapes(
shapes: Sequence[Tuple[int, ...]]) -> Optional[Tuple[int, ...]]:
assert shapes
if len(shapes) == 1: return shapes[0]
rank, *others = {len(shape) for shape in shapes}
if others: return None # must have consistent rank
@ -113,6 +127,7 @@ def _broadcast_shapes_cached(*shapes: Tuple[int, ...]) -> Tuple[int, ...]:
return _broadcast_shapes_uncached(*shapes)
def _broadcast_shapes_uncached(*shapes):
_validate_shapes(shapes)
fst, *rst = shapes
if not rst: return fst

View File

@ -723,6 +723,15 @@ class LaxVmapTest(jtu.JaxTestCase):
out_shape = lax.broadcast_shapes(shape1, shape2)
self.assertTrue(all(type(s) is int for s in out_shape))
def testBroadcastShapesFaultyInputs(self):
err_shape1, err_shape2 = (-1,), "hello"
# negative inputs should fail while informing about illegal negative indices...
with self.assertRaisesRegex(TypeError, "Only non-negative indices are allowed.*"):
lax.broadcast_shapes(err_shape1)
# ... while non-integers should error earlier, in the canonicalize_shape machinery.
with self.assertRaisesRegex(TypeError, "Shapes must be 1D sequences.*"):
lax.broadcast_shapes(err_shape2)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_k={}_bdims={}".format(
jtu.format_shape_dtype_string(shape, dtype), k, bdims),