Export jnp.broadcast_shapes as user facing function

This commit is contained in:
Lukas Geiger 2021-04-30 19:32:51 +02:00
parent bd35379843
commit a613ce12a3
4 changed files with 29 additions and 1 deletions

View File

@ -87,6 +87,7 @@ Not every function in NumPy is implemented; contributions are welcome!
block
bool_
broadcast_arrays
broadcast_shapes
broadcast_to
can_cast
cbrt

View File

@ -1733,6 +1733,12 @@ def bincount(x, weights=None, minlength=0, *, length=None):
raise ValueError("shape of weights must match shape of x.")
return zeros(length, _dtype(weights)).at[clip(x, 0)].add(weights)
@_wraps(getattr(np, "broadcast_shapes", None))
def broadcast_shapes(*shapes):
if not shapes:
return ()
shapes = [(shape,) if np.ndim(shape) == 0 else tuple(shape) for shape in shapes]
return lax.broadcast_shapes(*shapes)
def broadcast_arrays(*args):
"""Like Numpy's broadcast_arrays but doesn't return views."""

View File

@ -25,7 +25,7 @@ from jax._src.numpy.lax_numpy import (
arcsinh, arctan, arctan2, arctanh, argmax, argmin, argsort, argwhere, around,
array, array_equal, array_equiv, array_repr, array_split, array_str, asarray, atleast_1d, atleast_2d,
atleast_3d, average, bartlett, bfloat16, bincount, bitwise_and, bitwise_not,
bitwise_or, bitwise_xor, blackman, block, bool_, broadcast_arrays,
bitwise_or, bitwise_xor, blackman, block, bool_, broadcast_arrays, broadcast_shapes,
broadcast_to, can_cast, cbrt, cdouble, ceil, character, choose, clip, column_stack,
complex128, complex64, complex_, complexfloating, compress, concatenate,
conj, conjugate, convolve, copysign, corrcoef, correlate, cos, cosh,

View File

@ -4980,6 +4980,24 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(
{"testcase_name": f"_{shapes}", "shapes": shapes, "broadcasted_shape": broadcasted_shape}
for shapes, broadcasted_shape in [
[[], ()],
[[()], ()],
[[(1, 3), (4, 3)], (4, 3)],
[[(3,), (2, 1, 3)], (2, 1, 3)],
[[(3,), (3, 3)], (3, 3)],
[[(1,), (3,)], (3,)],
[[(1,), 3], (3,)],
[[(6, 7), (5, 6, 1), (7,), (5, 1, 7)], (5, 6, 7)],
[[[1], [0, 1]], (0, 1)],
[[(1,), np.array([0, 1])], (0, 1)],
])
def testBroadcastShapes(self, shapes, broadcasted_shape):
# Test against np.broadcast_shapes once numpy 1.20 is minimum required version
np.testing.assert_equal(jnp.broadcast_shapes(*shapes), broadcasted_shape)
def testBroadcastToIssue1522(self):
self.assertRaisesRegex(
ValueError, "Incompatible shapes for broadcasting: .*",
@ -5279,6 +5297,9 @@ class NumpySignaturesTest(jtu.JaxTestCase):
mismatches = {}
for name, (jnp_fun, np_fun) in func_pairs.items():
# broadcast_shapes is not available in numpy < 1.20
if np.__version__ < "1.20" and name == "broadcast_shapes":
continue
# Some signatures have changed; skip for older numpy versions.
if np.__version__ < "1.19" and name in ['einsum_path', 'gradient', 'isscalar']:
continue