Improved argument checking for lax.broadcast_in_dim

* Added a check that the output shape has rank greater than or equal to the input's
* Added a check that broadcast_dimensions is sorted (required by XLA)
* Relaxed the size check so that operand dimensions of size 1 are accepted and broadcast
* Added a docstring for lax.broadcast_in_dim
George Necula, 2020-03-16 09:54:58 +01:00
commit 5cf82c756e (parent c0c3a4a506)
3 changed files with 50 additions and 23 deletions
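
For orientation, a minimal sketch of the headline change, the relaxed size-1 rule (assuming a JAX build that includes this commit; NumPy is used only to construct the input):

    import numpy as np
    from jax import lax

    x = np.ones((1, 3), np.float32)
    # Operand dimensions of size 1 are now accepted: operand dim 0 (size 1)
    # maps to output dim 1 (size 4) and is broadcast up to it.
    y = lax.broadcast_in_dim(x, shape=(2, 4, 3), broadcast_dimensions=(1, 2))
    print(y.shape)  # (2, 4, 3)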


@@ -615,6 +615,11 @@ def broadcast(operand, sizes):
   return broadcast_in_dim(operand, tuple(sizes) + onp.shape(operand), dims)
 
 def broadcast_in_dim(operand, shape, broadcast_dimensions):
+  """Wraps XLA's `BroadcastInDim
+  <https://www.tensorflow.org/xla/operation_semantics#broadcastindim>`_
+  operator.
+  """
+  shape = _broadcast_in_dim_shape_rule(operand, shape, broadcast_dimensions)
   if onp.ndim(operand) == len(shape) and not len(broadcast_dimensions):
     return operand
   return broadcast_in_dim_p.bind(
@@ -2355,25 +2360,31 @@ def _broadcast_in_dim_shape_rule(operand, shape, broadcast_dimensions):
   _check_shapelike('broadcast_in_dim', 'shape', shape)
   _check_shapelike('broadcast_in_dim', 'broadcast_dimensions',
                    broadcast_dimensions)
-  if any(x >= len(shape) for x in broadcast_dimensions):
-    msg = ("broadcast_in_dim broadcast dimensions must be less than "
-           "ndim(shape), got {} for shape {}.")
-    raise ValueError(msg.format(broadcast_dimensions, shape))
   if operand.ndim != len(broadcast_dimensions):
     msg = ('broadcast_in_dim broadcast_dimensions must have length equal to '
-           'operand ndim, got broadcast_dimensions {} for operand ndim {}.')
+           'operand ndim; got broadcast_dimensions {} for operand ndim {}.')
     raise TypeError(msg.format(broadcast_dimensions, operand.ndim))
+  if len(shape) < operand.ndim:
+    msg = ('broadcast_in_dim target broadcast shape must have equal or higher rank '
+           'to the operand shape; got operand ndim {} and target broadcast ndim {}.')
+    raise TypeError(msg.format(operand.ndim, len(shape)))
+  if not set(broadcast_dimensions).issubset(set(range(len(shape)))):
+    msg = ('broadcast_in_dim broadcast_dimensions must be a subset of output '
+           'dimensions, got {} for operand ndim {} and shape {}.')
+    raise TypeError(msg.format(broadcast_dimensions, operand.ndim, shape))
-  if any(operand.shape[i] != shape[broadcast_dimensions[i]]
+  if any(operand.shape[i] != 1 and operand.shape[i] != shape[broadcast_dimensions[i]]
          for i in range(operand.ndim)):
-    msg = ('broadcast_in_dim operand dimension sizes must equal their '
-           'corresponding dimensions in the broadcasted-to shape; got '
-           'operand of shape {}, target broadcast shape {}, '
-           'broadcast_dimensions {} ')
-    raise TypeError(msg.format(operand.shape, shape, broadcast_dimensions))
+    msg = ('broadcast_in_dim operand dimension sizes must either be 1, or be '
+           'equal to their corresponding dimensions in the target broadcast shape; '
+           'got operand of shape {}, target broadcast shape {}, '
+           'broadcast_dimensions {} ')
+    raise TypeError(msg.format(operand.shape, shape, broadcast_dimensions))
+  if (len(broadcast_dimensions) != len(set(broadcast_dimensions)) or
+      tuple(broadcast_dimensions) != tuple(sorted(broadcast_dimensions))):
+    msg = ('broadcast_in_dim broadcast_dimensions must be strictly increasing; '
+           'got broadcast_dimensions {}')
+    raise TypeError(msg.format(broadcast_dimensions))
   return shape
 
 def _broadcast_in_dim_transpose_rule(t, shape, broadcast_dimensions):
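
To make the new failure modes concrete, a hedged sketch (the inputs mirror the test cases added below; the exact error text follows the messages above):

    import numpy as np
    from jax import lax

    x = np.ones((2, 2), np.float32)

    # A target shape of lower rank than the operand now raises eagerly.
    try:
      lax.broadcast_in_dim(x, shape=(2,), broadcast_dimensions=(0, 1))
    except TypeError as e:
      print(e)  # ...target broadcast shape must have equal or higher rank...

    # Unsorted broadcast_dimensions are rejected (XLA requires sorted dims).
    try:
      lax.broadcast_in_dim(x, shape=(2, 2), broadcast_dimensions=(1, 0))
    except TypeError as e:
      print(e)  # ...broadcast_dimensions must be strictly increasing...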


@@ -188,9 +188,10 @@ def broadcast(operand, sizes):
   return onp.broadcast_to(operand, sizes + onp.shape(operand))
 
 def broadcast_in_dim(operand, shape, broadcast_dimensions):
-  inshape = tuple(1 if i not in broadcast_dimensions else d
-                  for i, d in enumerate(shape))
-  return onp.broadcast_to(onp.reshape(operand, inshape), shape)
+  in_reshape = onp.ones(len(shape), dtype=onp.int32)
+  for i, bd in enumerate(broadcast_dimensions):
+    in_reshape[bd] = operand.shape[i]
+  return onp.broadcast_to(onp.reshape(operand, in_reshape), shape)
 
 sum = onp.sum
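
Why the reference implementation changed: the old version built the reshape sizes from the target shape, which breaks once size-1 operand dimensions are allowed (reshaping a size-1 operand dim to the full target size fails). A self-contained NumPy sketch of the new logic:

    import numpy as onp

    def broadcast_in_dim(operand, shape, broadcast_dimensions):
      # Place each operand dimension at its target position; every other
      # target position gets size 1 so broadcast_to can expand it.
      in_reshape = onp.ones(len(shape), dtype=onp.int32)
      for i, bd in enumerate(broadcast_dimensions):
        in_reshape[bd] = operand.shape[i]
      return onp.broadcast_to(onp.reshape(operand, in_reshape), shape)

    # Old version: reshape of a (1, 3) operand to (2, 3) would fail.
    # New version: reshape to (1, 3), then broadcast_to((2, 3)).
    print(broadcast_in_dim(onp.ones((1, 3)), (2, 3), (0, 1)).shape)  # (2, 3)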


@@ -822,6 +822,7 @@ class LaxTest(jtu.JaxTestCase):
           ([2], [2, 2], [1]),
           ([2], [2, 3], [0]),
           ([], [2, 3], []),
+          ([1], [2, 3], [1]),
       ]
       for dtype in default_dtypes
       for rng_factory in [jtu.rand_default]))
@@ -831,16 +832,29 @@ class LaxTest(jtu.JaxTestCase):
     op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
     self._CompileAndCheck(op, args_maker, check_dtypes=True)
 
-  def testBroadcastInDimShapeCheck(self):
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name": "_inshape={}_outshape={}_bcdims={}".format(
+          jtu.format_shape_dtype_string(inshape, onp.float32),
+          outshape, broadcast_dimensions),
+       "inshape": inshape, "outshape": outshape,
+       "broadcast_dimensions": broadcast_dimensions, "err_msg": err_msg}
+      for inshape, outshape, broadcast_dimensions, err_msg in [
+        ([2], [2, 2], [0, 1], ('broadcast_dimensions must have length equal to '
+                               'operand ndim')),
+        ([2, 2], [2], [0, 1], ('target broadcast shape must have equal or higher rank '
+                               'to the operand shape')),
+        ([2], [2, 3], [2], ('broadcast_in_dim broadcast_dimensions must be a subset of output '
+                            'dimensions')),
+        ([2], [3], [0], ('operand dimension sizes must either be 1, or be '
+                         'equal to their corresponding dimensions in the target broadcast shape')),
+        ([2, 2], [2, 2], [1, 0], ('broadcast_dimensions must be strictly increasing')),
+      ]))
+  def testBroadcastInDimShapeCheck(self, inshape, outshape, broadcast_dimensions, err_msg):
     rng = jtu.rand_default()
-    x = rng((6, 7), onp.float32)
-    def op(x):
-      lax.broadcast_in_dim(x, broadcast_dimensions=(1, 2), shape=(3, 4, 5))
-    self.assertRaisesRegex(
-        TypeError,
-        ("broadcast_in_dim operand dimension sizes must equal their "
-         "corresponding dimensions in the broadcasted-to shape;*"),
-        lambda: op(x))
+    x = rng(inshape, onp.float32)
+    with self.assertRaisesRegex(TypeError, err_msg):
+      lax.broadcast_in_dim(x, shape=outshape, broadcast_dimensions=broadcast_dimensions)
 
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_inshape={}_outshape={}_bcdims={}".format(
@@ -853,6 +867,7 @@ class LaxTest(jtu.JaxTestCase):
           ([2], [2, 2], [1]),
           ([2], [2, 3], [0]),
           ([], [2, 3], []),
+          ([1], [2, 3], [1]),
       ]
       for dtype in default_dtypes
       for rng_factory in [jtu.rand_default]))