mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 13:26:06 +00:00
Require arraylike input for several jax.numpy functions
PiperOrigin-RevId: 630532821
This commit is contained in:
parent
53208ffe27
commit
e95173a4d3
@ -68,6 +68,11 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
positional-only, following deprecation of the keywords in JAX v0.4.21.
|
||||
* Non-array arguments to functions in {mod}`jax.lax.linalg` now must be
|
||||
specified by keyword. Previously, this raised a DeprecationWarning.
|
||||
* Array-like arguments are now required in several :func:`jax.numpy` APIs,
|
||||
including {func}`~jax.numpy.apply_along_axis`,
|
||||
{func}`~jax.numpy.apply_over_axes`, {func}`~jax.numpy.inner`,
|
||||
{func}`~jax.numpy.outer`, {func}`~jax.numpy.cross`,
|
||||
{func}`~jax.numpy.kron`, and {func}`~jax.numpy.lexsort`.
|
||||
|
||||
* Bug fixes
|
||||
* {func}`jax.numpy.astype` will now always return a copy when `copy=True`.
|
||||
|
@ -3759,8 +3759,7 @@ def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike,
|
||||
def apply_along_axis(
|
||||
func1d: Callable, axis: int, arr: ArrayLike, *args, **kwargs
|
||||
) -> Array:
|
||||
# TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
|
||||
util.check_arraylike("apply_along_axis", arr, emit_warning=True)
|
||||
util.check_arraylike("apply_along_axis", arr)
|
||||
num_dims = ndim(arr)
|
||||
axis = _canonicalize_axis(axis, num_dims)
|
||||
func = lambda arr: func1d(arr, *args, **kwargs)
|
||||
@ -3774,8 +3773,7 @@ def apply_along_axis(
|
||||
@util.implements(np.apply_over_axes)
|
||||
def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike,
|
||||
axes: Sequence[int]) -> Array:
|
||||
# TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
|
||||
util.check_arraylike("apply_over_axes", a, emit_warning=True)
|
||||
util.check_arraylike("apply_over_axes", a)
|
||||
a_arr = asarray(a)
|
||||
for axis in axes:
|
||||
b = func(a_arr, axis)
|
||||
@ -4233,8 +4231,7 @@ def inner(
|
||||
a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None,
|
||||
preferred_element_type: DType | None = None,
|
||||
) -> Array:
|
||||
# TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
|
||||
util.check_arraylike("inner", a, b, emit_warning=True)
|
||||
util.check_arraylike("inner", a, b)
|
||||
if ndim(a) == 0 or ndim(b) == 0:
|
||||
a = asarray(a, dtype=preferred_element_type)
|
||||
b = asarray(b, dtype=preferred_element_type)
|
||||
@ -4248,8 +4245,7 @@ def inner(
|
||||
def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array:
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.outer is not supported.")
|
||||
# TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
|
||||
util.check_arraylike("outer", a, b, emit_warning=True)
|
||||
util.check_arraylike("outer", a, b)
|
||||
a, b = util.promote_dtypes(a, b)
|
||||
return ravel(a)[:, None] * ravel(b)[None, :]
|
||||
|
||||
@ -4258,8 +4254,7 @@ def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array:
|
||||
def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1,
|
||||
axis: int | None = None):
|
||||
# TODO(jakevdp): NumPy 2.0 deprecates 2D inputs. Follow suit here.
|
||||
# TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
|
||||
util.check_arraylike("cross", a, b, emit_warning=True)
|
||||
util.check_arraylike("cross", a, b)
|
||||
if axis is not None:
|
||||
axisa = axis
|
||||
axisb = axis
|
||||
@ -4286,8 +4281,7 @@ def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1,
|
||||
@util.implements(np.kron)
|
||||
@jit
|
||||
def kron(a: ArrayLike, b: ArrayLike) -> Array:
|
||||
# TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
|
||||
util.check_arraylike("kron", a, b, emit_warning=True)
|
||||
util.check_arraylike("kron", a, b)
|
||||
a, b = util.promote_dtypes(a, b)
|
||||
if ndim(a) < ndim(b):
|
||||
a = expand_dims(a, range(ndim(b) - ndim(a)))
|
||||
@ -4530,8 +4524,7 @@ def sort_complex(a: ArrayLike) -> Array:
|
||||
@partial(jit, static_argnames=('axis',))
|
||||
def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> Array:
|
||||
key_tuple = tuple(keys)
|
||||
# TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
|
||||
util.check_arraylike("lexsort", *key_tuple, emit_warning=True)
|
||||
util.check_arraylike("lexsort", *key_tuple)
|
||||
key_arrays = tuple(asarray(k) for k in key_tuple)
|
||||
if len(key_arrays) == 0:
|
||||
raise TypeError("need sequence of keys with len > 0 in lexsort")
|
||||
|
Loading…
x
Reference in New Issue
Block a user