mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Export jnp.broadcast_shapes
as user facing function
This commit is contained in:
parent
bd35379843
commit
a613ce12a3
@ -87,6 +87,7 @@ Not every function in NumPy is implemented; contributions are welcome!
|
||||
block
|
||||
bool_
|
||||
broadcast_arrays
|
||||
broadcast_shapes
|
||||
broadcast_to
|
||||
can_cast
|
||||
cbrt
|
||||
|
@ -1733,6 +1733,12 @@ def bincount(x, weights=None, minlength=0, *, length=None):
|
||||
raise ValueError("shape of weights must match shape of x.")
|
||||
return zeros(length, _dtype(weights)).at[clip(x, 0)].add(weights)
|
||||
|
||||
@_wraps(getattr(np, "broadcast_shapes", None))
|
||||
def broadcast_shapes(*shapes):
|
||||
if not shapes:
|
||||
return ()
|
||||
shapes = [(shape,) if np.ndim(shape) == 0 else tuple(shape) for shape in shapes]
|
||||
return lax.broadcast_shapes(*shapes)
|
||||
|
||||
def broadcast_arrays(*args):
|
||||
"""Like Numpy's broadcast_arrays but doesn't return views."""
|
||||
|
@ -25,7 +25,7 @@ from jax._src.numpy.lax_numpy import (
|
||||
arcsinh, arctan, arctan2, arctanh, argmax, argmin, argsort, argwhere, around,
|
||||
array, array_equal, array_equiv, array_repr, array_split, array_str, asarray, atleast_1d, atleast_2d,
|
||||
atleast_3d, average, bartlett, bfloat16, bincount, bitwise_and, bitwise_not,
|
||||
bitwise_or, bitwise_xor, blackman, block, bool_, broadcast_arrays,
|
||||
bitwise_or, bitwise_xor, blackman, block, bool_, broadcast_arrays, broadcast_shapes,
|
||||
broadcast_to, can_cast, cbrt, cdouble, ceil, character, choose, clip, column_stack,
|
||||
complex128, complex64, complex_, complexfloating, compress, concatenate,
|
||||
conj, conjugate, convolve, copysign, corrcoef, correlate, cos, cosh,
|
||||
|
@ -4980,6 +4980,24 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{shapes}", "shapes": shapes, "broadcasted_shape": broadcasted_shape}
|
||||
for shapes, broadcasted_shape in [
|
||||
[[], ()],
|
||||
[[()], ()],
|
||||
[[(1, 3), (4, 3)], (4, 3)],
|
||||
[[(3,), (2, 1, 3)], (2, 1, 3)],
|
||||
[[(3,), (3, 3)], (3, 3)],
|
||||
[[(1,), (3,)], (3,)],
|
||||
[[(1,), 3], (3,)],
|
||||
[[(6, 7), (5, 6, 1), (7,), (5, 1, 7)], (5, 6, 7)],
|
||||
[[[1], [0, 1]], (0, 1)],
|
||||
[[(1,), np.array([0, 1])], (0, 1)],
|
||||
])
|
||||
def testBroadcastShapes(self, shapes, broadcasted_shape):
|
||||
# Test against np.broadcast_shapes once numpy 1.20 is minimum required version
|
||||
np.testing.assert_equal(jnp.broadcast_shapes(*shapes), broadcasted_shape)
|
||||
|
||||
def testBroadcastToIssue1522(self):
|
||||
self.assertRaisesRegex(
|
||||
ValueError, "Incompatible shapes for broadcasting: .*",
|
||||
@ -5279,6 +5297,9 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
mismatches = {}
|
||||
|
||||
for name, (jnp_fun, np_fun) in func_pairs.items():
|
||||
# broadcast_shapes is not available in numpy < 1.20
|
||||
if np.__version__ < "1.20" and name == "broadcast_shapes":
|
||||
continue
|
||||
# Some signatures have changed; skip for older numpy versions.
|
||||
if np.__version__ < "1.19" and name in ['einsum_path', 'gradient', 'isscalar']:
|
||||
continue
|
||||
|
Loading…
x
Reference in New Issue
Block a user