mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 21:36:05 +00:00
Call _check_arraylike on inputs to broadcast_to and broadcast_arrays
This commit is contained in:
parent
43e65b3dbb
commit
8c57ae2a19
@ -18,6 +18,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
instead, e.g., `x.at[idx].set(y)`.
|
||||
* Moved `jax.experimental.ann.approx_*_k` into `jax.lax`. These functions are
|
||||
optimized alternatives to `jax.lax.top_k`.
|
||||
* {func}`jax.numpy.broadcast_arrays` and {func}`jax.numpy.broadcast_to` now require scalar
|
||||
or array-like inputs, and will fail if they are passed lists (part of {jax-issue}`#7737`).
|
||||
|
||||
|
||||
## jaxlib 0.3.1 (Unreleased)
|
||||
|
@ -1963,6 +1963,8 @@ def broadcast_shapes(*shapes):
|
||||
@partial(jit, inline=True)
|
||||
def broadcast_arrays(*args):
|
||||
"""Like Numpy's broadcast_arrays but doesn't return views."""
|
||||
# Avoid calling _check_arraylike() here to allow passing through objects
|
||||
# like PRNGKeyArray which are specially handled in broadcast_to() below.
|
||||
shapes = [shape(arg) for arg in args]
|
||||
if not shapes or _all(core.symbolic_equal_shape(shapes[0], s) for s in shapes):
|
||||
# TODO(mattjj): remove the array(arg) here
|
||||
@ -1978,6 +1980,7 @@ The JAX version does not necessarily return a view of the input.
|
||||
def broadcast_to(arr, shape):
|
||||
if hasattr(arr, "broadcast_to"):
|
||||
return arr.broadcast_to(shape)
|
||||
_check_arraylike("broadcast_to", arr)
|
||||
arr = arr if isinstance(arr, ndarray) else array(arr)
|
||||
if not isinstance(shape, tuple) and ndim(shape) == 0:
|
||||
shape = (shape,)
|
||||
|
Loading…
x
Reference in New Issue
Block a user