mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Validate lax.broadcast_shape
inputs before control flow execution
This commit addresses previously unvalidated inputs to `jax.lax.broadcast_shapes` by adding a small validation check before control flow execution. A legal input to `lax.broadcast_shapes` hereafter is defined as an input that 1) is a sequence (i.e., implements for..in iteration) of integers and 2) said integers are all non-negative. In addition, two tests were added to `tests.lax_vmap_test` to check that proper errors are raised when attempting to use illegal inputs with `lax.broadcast_shapes`.
This commit is contained in:
parent
7d02949d24
commit
56546d3e73
@ -83,9 +83,23 @@ Shape = core.Shape
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
def _validate_shapes(shapes: Sequence[Shape]):
|
||||
def _check_static_shape(shape: Shape):
|
||||
checked = canonicalize_shape(shape)
|
||||
if not all(idx >= 0 for idx in checked):
|
||||
msg = f"Only non-negative indices are allowed when broadcasting" \
|
||||
f" static shapes, but got shape {shape!r}."
|
||||
raise TypeError(msg)
|
||||
|
||||
assert shapes
|
||||
if config.jax_dynamic_shapes:
|
||||
# pass dynamic shapes through unchecked
|
||||
return
|
||||
else:
|
||||
_ = tuple(map(_check_static_shape, shapes))
|
||||
|
||||
def _try_broadcast_shapes(
|
||||
shapes: Sequence[Tuple[int, ...]]) -> Optional[Tuple[int, ...]]:
|
||||
assert shapes
|
||||
if len(shapes) == 1: return shapes[0]
|
||||
rank, *others = {len(shape) for shape in shapes}
|
||||
if others: return None # must have consistent rank
|
||||
@ -113,6 +127,7 @@ def _broadcast_shapes_cached(*shapes: Tuple[int, ...]) -> Tuple[int, ...]:
|
||||
return _broadcast_shapes_uncached(*shapes)
|
||||
|
||||
def _broadcast_shapes_uncached(*shapes):
|
||||
_validate_shapes(shapes)
|
||||
fst, *rst = shapes
|
||||
if not rst: return fst
|
||||
|
||||
|
@ -723,6 +723,15 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
out_shape = lax.broadcast_shapes(shape1, shape2)
|
||||
self.assertTrue(all(type(s) is int for s in out_shape))
|
||||
|
||||
def testBroadcastShapesFaultyInputs(self):
|
||||
err_shape1, err_shape2 = (-1,), "hello"
|
||||
# negative inputs should fail while informing about illegal negative indices...
|
||||
with self.assertRaisesRegex(TypeError, "Only non-negative indices are allowed.*"):
|
||||
lax.broadcast_shapes(err_shape1)
|
||||
# ... while non-integers should error earlier, in the canonicalize_shape machinery.
|
||||
with self.assertRaisesRegex(TypeError, "Shapes must be 1D sequences.*"):
|
||||
lax.broadcast_shapes(err_shape2)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_k={}_bdims={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), k, bdims),
|
||||
|
Loading…
x
Reference in New Issue
Block a user