mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Improved argument checking for lax.broadcast_in_dim
* Added checking that the output shape has higher or equal rank to input * Added checking that the broadcast_dims are sorted (required by XLA) * Relaxed check that operand dimension size can be 1 * Added lax.broadcast_in_dim docstring
This commit is contained in:
parent
c0c3a4a506
commit
5cf82c756e
@ -615,6 +615,11 @@ def broadcast(operand, sizes):
|
||||
return broadcast_in_dim(operand, tuple(sizes) + onp.shape(operand), dims)
|
||||
|
||||
def broadcast_in_dim(operand, shape, broadcast_dimensions):
|
||||
"""Wraps XLA's `BroadcastInDim
|
||||
<https://www.tensorflow.org/xla/operation_semantics#broadcastindim>`_
|
||||
operator.
|
||||
"""
|
||||
shape = _broadcast_in_dim_shape_rule(operand, shape, broadcast_dimensions)
|
||||
if onp.ndim(operand) == len(shape) and not len(broadcast_dimensions):
|
||||
return operand
|
||||
return broadcast_in_dim_p.bind(
|
||||
@ -2355,25 +2360,31 @@ def _broadcast_in_dim_shape_rule(operand, shape, broadcast_dimensions):
|
||||
_check_shapelike('broadcast_in_dim', 'shape', shape)
|
||||
_check_shapelike('broadcast_in_dim', 'broadcast_dimensions',
|
||||
broadcast_dimensions)
|
||||
if any(x >= len(shape) for x in broadcast_dimensions):
|
||||
msg = ("broadcast_in_dim broadcast dimensions must be less than "
|
||||
"ndim(shape), got {} for shape {}.")
|
||||
raise ValueError(msg.format(broadcast_dimensions, shape))
|
||||
if operand.ndim != len(broadcast_dimensions):
|
||||
msg = ('broadcast_in_dim broadcast_dimensions must have length equal to '
|
||||
'operand ndim, got broadcast_dimensions {} for operand ndim {}.')
|
||||
'operand ndim; got broadcast_dimensions {} for operand ndim {}.')
|
||||
raise TypeError(msg.format(broadcast_dimensions, operand.ndim))
|
||||
if len(shape) < operand.ndim:
|
||||
msg = ('broadcast_in_dim target broadcast shape must have equal or higher rank '
|
||||
'to the operand shape; got operand ndim {} and target broadcast ndim {}.')
|
||||
raise TypeError(msg.format(operand.ndim, len(shape)))
|
||||
if not set(broadcast_dimensions).issubset(set(range(len(shape)))):
|
||||
msg = ('broadcast_in_dim broadcast_dimensions must be a subset of output '
|
||||
'dimensions, got {} for operand ndim {} and shape {}.')
|
||||
raise TypeError(msg.format(broadcast_dimensions, operand.ndim, shape))
|
||||
if any(operand.shape[i] != shape[broadcast_dimensions[i]]
|
||||
if any(operand.shape[i] != 1 and operand.shape[i] != shape[broadcast_dimensions[i]]
|
||||
for i in range(operand.ndim)):
|
||||
msg = ('broadcast_in_dim operand dimension sizes must equal their '
|
||||
'corresponding dimensions in the broadcasted-to shape; got '
|
||||
'operand of shape {}, target broadcast shape {}, '
|
||||
'broadcast_dimensions {} ')
|
||||
raise TypeError(msg.format(operand.shape, shape, broadcast_dimensions))
|
||||
msg = ('broadcast_in_dim operand dimension sizes must either be 1, or be '
|
||||
'equal to their corresponding dimensions in the target broadcast shape; '
|
||||
'got operand of shape {}, target broadcast shape {}, '
|
||||
'broadcast_dimensions {} ')
|
||||
raise TypeError(msg.format(operand.shape, shape, broadcast_dimensions))
|
||||
if (len(broadcast_dimensions) != len(set(broadcast_dimensions)) or
|
||||
tuple(broadcast_dimensions) != tuple(sorted(broadcast_dimensions))):
|
||||
msg = ('broadcast_in_dim broadcast_dimensions must be strictly increasing; '
|
||||
'got broadcast_dimensions {}')
|
||||
raise TypeError(msg.format(broadcast_dimensions))
|
||||
|
||||
return shape
|
||||
|
||||
def _broadcast_in_dim_transpose_rule(t, shape, broadcast_dimensions):
|
||||
|
@ -188,9 +188,10 @@ def broadcast(operand, sizes):
|
||||
return onp.broadcast_to(operand, sizes + onp.shape(operand))
|
||||
|
||||
def broadcast_in_dim(operand, shape, broadcast_dimensions):
|
||||
inshape = tuple(1 if i not in broadcast_dimensions else d
|
||||
for i, d in enumerate(shape))
|
||||
return onp.broadcast_to(onp.reshape(operand, inshape), shape)
|
||||
in_reshape = onp.ones(len(shape), dtype=onp.int32)
|
||||
for i, bd in enumerate(broadcast_dimensions):
|
||||
in_reshape[bd] = operand.shape[i]
|
||||
return onp.broadcast_to(onp.reshape(operand, in_reshape), shape)
|
||||
|
||||
sum = onp.sum
|
||||
|
||||
|
@ -822,6 +822,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
([2], [2, 2], [1]),
|
||||
([2], [2, 3], [0]),
|
||||
([], [2, 3], []),
|
||||
([1], [2, 3], [1]),
|
||||
]
|
||||
for dtype in default_dtypes
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
@ -831,16 +832,29 @@ class LaxTest(jtu.JaxTestCase):
|
||||
op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
|
||||
self._CompileAndCheck(op, args_maker, check_dtypes=True)
|
||||
|
||||
def testBroadcastInDimShapeCheck(self):
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_inshape={}_outshape={}_bcdims={}".format(
|
||||
jtu.format_shape_dtype_string(inshape, onp.float32),
|
||||
outshape, broadcast_dimensions),
|
||||
"inshape": inshape, "outshape": outshape,
|
||||
"broadcast_dimensions": broadcast_dimensions, "err_msg": err_msg}
|
||||
for inshape, outshape, broadcast_dimensions, err_msg in [
|
||||
([2], [2, 2], [0, 1], ('broadcast_dimensions must have length equal to '
|
||||
'operand ndim')),
|
||||
([2, 2], [2], [0, 1], ('target broadcast shape must have equal or higher rank '
|
||||
'to the operand shape')),
|
||||
([2], [2, 3], [2], ('broadcast_in_dim broadcast_dimensions must be a subset of output '
|
||||
'dimensions')),
|
||||
([2], [3], [0], ('operand dimension sizes must either be 1, or be '
|
||||
'equal to their corresponding dimensions in the target broadcast shape')),
|
||||
([2, 2], [2, 2], [1, 0], ('broadcast_dimensions must be strictly increasing')),
|
||||
]))
|
||||
def testBroadcastInDimShapeCheck(self, inshape, outshape, broadcast_dimensions, err_msg):
|
||||
rng = jtu.rand_default()
|
||||
x = rng((6, 7), onp.float32)
|
||||
def op(x):
|
||||
lax.broadcast_in_dim(x, broadcast_dimensions=(1, 2), shape=(3, 4, 5))
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
("broadcast_in_dim operand dimension sizes must equal their "
|
||||
"corresponding dimensions in the broadcasted-to shape;*"),
|
||||
lambda: op(x))
|
||||
x = rng(inshape, onp.float32)
|
||||
with self.assertRaisesRegex(TypeError, err_msg):
|
||||
lax.broadcast_in_dim(x, shape=outshape, broadcast_dimensions=broadcast_dimensions)
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_inshape={}_outshape={}_bcdims={}".format(
|
||||
@ -853,6 +867,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
([2], [2, 2], [1]),
|
||||
([2], [2, 3], [0]),
|
||||
([], [2, 3], []),
|
||||
([1], [2, 3], [1]),
|
||||
]
|
||||
for dtype in default_dtypes
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
|
Loading…
x
Reference in New Issue
Block a user