1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 13:26:06 +00:00

Require arraylike input for several jax.numpy functions

PiperOrigin-RevId: 630532821
This commit is contained in:
Jake VanderPlas 2024-05-03 16:54:22 -07:00 committed by jax authors
parent 53208ffe27
commit e95173a4d3
2 changed files with 12 additions and 14 deletions

@ -68,6 +68,11 @@ Remember to align the itemized text with the first line of an item within a list
positional-only, following deprecation of the keywords in JAX v0.4.21.
* Non-array arguments to functions in {mod}`jax.lax.linalg` now must be
specified by keyword. Previously, this raised a DeprecationWarning.
* Array-like arguments are now required in several :func:`jax.numpy` APIs,
including {func}`~jax.numpy.apply_along_axis`,
{func}`~jax.numpy.apply_over_axes`, {func}`~jax.numpy.inner`,
{func}`~jax.numpy.outer`, {func}`~jax.numpy.cross`,
{func}`~jax.numpy.kron`, and {func}`~jax.numpy.lexsort`.
* Bug fixes
* {func}`jax.numpy.astype` will now always return a copy when `copy=True`.

@ -3759,8 +3759,7 @@ def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike,
def apply_along_axis(
func1d: Callable, axis: int, arr: ArrayLike, *args, **kwargs
) -> Array:
# TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
util.check_arraylike("apply_along_axis", arr, emit_warning=True)
util.check_arraylike("apply_along_axis", arr)
num_dims = ndim(arr)
axis = _canonicalize_axis(axis, num_dims)
func = lambda arr: func1d(arr, *args, **kwargs)
@ -3774,8 +3773,7 @@ def apply_along_axis(
@util.implements(np.apply_over_axes)
def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike,
axes: Sequence[int]) -> Array:
# TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
util.check_arraylike("apply_over_axes", a, emit_warning=True)
util.check_arraylike("apply_over_axes", a)
a_arr = asarray(a)
for axis in axes:
b = func(a_arr, axis)
@ -4233,8 +4231,7 @@ def inner(
a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None,
preferred_element_type: DType | None = None,
) -> Array:
# TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
util.check_arraylike("inner", a, b, emit_warning=True)
util.check_arraylike("inner", a, b)
if ndim(a) == 0 or ndim(b) == 0:
a = asarray(a, dtype=preferred_element_type)
b = asarray(b, dtype=preferred_element_type)
@ -4248,8 +4245,7 @@ def inner(
def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.outer is not supported.")
# TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
util.check_arraylike("outer", a, b, emit_warning=True)
util.check_arraylike("outer", a, b)
a, b = util.promote_dtypes(a, b)
return ravel(a)[:, None] * ravel(b)[None, :]
@ -4258,8 +4254,7 @@ def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array:
def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1,
axis: int | None = None):
# TODO(jakevdp): NumPy 2.0 deprecates 2D inputs. Follow suit here.
# TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
util.check_arraylike("cross", a, b, emit_warning=True)
util.check_arraylike("cross", a, b)
if axis is not None:
axisa = axis
axisb = axis
@ -4286,8 +4281,7 @@ def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1,
@util.implements(np.kron)
@jit
def kron(a: ArrayLike, b: ArrayLike) -> Array:
# TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
util.check_arraylike("kron", a, b, emit_warning=True)
util.check_arraylike("kron", a, b)
a, b = util.promote_dtypes(a, b)
if ndim(a) < ndim(b):
a = expand_dims(a, range(ndim(b) - ndim(a)))
@ -4530,8 +4524,7 @@ def sort_complex(a: ArrayLike) -> Array:
@partial(jit, static_argnames=('axis',))
def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> Array:
key_tuple = tuple(keys)
# TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
util.check_arraylike("lexsort", *key_tuple, emit_warning=True)
util.check_arraylike("lexsort", *key_tuple)
key_arrays = tuple(asarray(k) for k in key_tuple)
if len(key_arrays) == 0:
raise TypeError("need sequence of keys with len > 0 in lexsort")