1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 21:36:05 +00:00

Call _check_arraylike on inputs to broadcast_to and broadcast_arrays

This commit is contained in:
Jake VanderPlas 2022-03-02 15:23:34 -08:00
parent 43e65b3dbb
commit 8c57ae2a19
2 changed files with 5 additions and 0 deletions

@ -18,6 +18,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
instead, e.g., `x.at[idx].set(y)`.
* Moved `jax.experimental.ann.approx_*_k` into `jax.lax`. These functions are
optimized alternatives to `jax.lax.top_k`.
* {func}`jax.numpy.broadcast_arrays` and {func}`jax.numpy.broadcast_to` now require scalar
or array-like inputs, and will fail if they are passed lists (part of {jax-issue}`#7737`).
## jaxlib 0.3.1 (Unreleased)

@ -1963,6 +1963,8 @@ def broadcast_shapes(*shapes):
@partial(jit, inline=True)
def broadcast_arrays(*args):
"""Like Numpy's broadcast_arrays but doesn't return views."""
# Avoid calling _check_arraylike() here to allow passing through objects
# like PRNGKeyArray which are specially handled in broadcast_to() below.
shapes = [shape(arg) for arg in args]
if not shapes or _all(core.symbolic_equal_shape(shapes[0], s) for s in shapes):
# TODO(mattjj): remove the array(arg) here
@ -1978,6 +1980,7 @@ The JAX version does not necessarily return a view of the input.
def broadcast_to(arr, shape):
if hasattr(arr, "broadcast_to"):
return arr.broadcast_to(shape)
_check_arraylike("broadcast_to", arr)
arr = arr if isinstance(arr, ndarray) else array(arr)
if not isinstance(shape, tuple) and ndim(shape) == 0:
shape = (shape,)