mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Merge pull request #9576 from nicholasjng:broadcast-validation
PiperOrigin-RevId: 432531230
This commit is contained in:
commit
2a3f936ffa
@ -83,9 +83,23 @@ Shape = core.Shape
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
def _validate_shapes(shapes: Sequence[Shape]):
|
||||
def _check_static_shape(shape: Shape):
|
||||
checked = canonicalize_shape(shape)
|
||||
if not all(idx >= 0 for idx in checked):
|
||||
msg = f"Only non-negative indices are allowed when broadcasting" \
|
||||
f" static shapes, but got shape {shape!r}."
|
||||
raise TypeError(msg)
|
||||
|
||||
assert shapes
|
||||
if config.jax_dynamic_shapes:
|
||||
# pass dynamic shapes through unchecked
|
||||
return
|
||||
else:
|
||||
_ = tuple(map(_check_static_shape, shapes))
|
||||
|
||||
def _try_broadcast_shapes(
|
||||
shapes: Sequence[Tuple[int, ...]]) -> Optional[Tuple[int, ...]]:
|
||||
assert shapes
|
||||
if len(shapes) == 1: return shapes[0]
|
||||
rank, *others = {len(shape) for shape in shapes}
|
||||
if others: return None # must have consistent rank
|
||||
@ -113,6 +127,7 @@ def _broadcast_shapes_cached(*shapes: Tuple[int, ...]) -> Tuple[int, ...]:
|
||||
return _broadcast_shapes_uncached(*shapes)
|
||||
|
||||
def _broadcast_shapes_uncached(*shapes):
|
||||
_validate_shapes(shapes)
|
||||
fst, *rst = shapes
|
||||
if not rst: return fst
|
||||
|
||||
|
@ -723,6 +723,15 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
out_shape = lax.broadcast_shapes(shape1, shape2)
|
||||
self.assertTrue(all(type(s) is int for s in out_shape))
|
||||
|
||||
def testBroadcastShapesFaultyInputs(self):
|
||||
err_shape1, err_shape2 = (-1,), "hello"
|
||||
# negative inputs should fail while informing about illegal negative indices...
|
||||
with self.assertRaisesRegex(TypeError, "Only non-negative indices are allowed.*"):
|
||||
lax.broadcast_shapes(err_shape1)
|
||||
# ... while non-integers should error earlier, in the canonicalize_shape machinery.
|
||||
with self.assertRaisesRegex(TypeError, "Shapes must be 1D sequences.*"):
|
||||
lax.broadcast_shapes(err_shape2) # pytype: disable=wrong-arg-types
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_k={}_bdims={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), k, bdims),
|
||||
|
Loading…
x
Reference in New Issue
Block a user