internal: move more NumPy APIs to ensure_arraylike

This commit is contained in:
Jake VanderPlas 2025-01-23 08:48:13 -08:00
parent 6b95ad0a53
commit 23c1d62910
2 changed files with 25 additions and 28 deletions

View File

@ -22,7 +22,7 @@ from jax import dtypes
from jax import lax
from jax._src.lib import xla_client
from jax._src.util import safe_zip
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact
from jax._src.numpy.util import ensure_arraylike, promote_dtypes_inexact
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import ufuncs, reductions
from jax._src.sharding import Sharding
@ -49,8 +49,7 @@ def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike,
s: Shape | None, axes: Sequence[int] | None,
norm: str | None) -> Array:
full_name = f"jax.numpy.fft.{func_name}"
check_arraylike(full_name, a)
arr = jnp.asarray(a)
arr = ensure_arraylike(full_name, a)
if s is not None:
s = tuple(map(operator.index, s))
@ -1287,8 +1286,7 @@ def fftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array:
>>> jnp.fft.ifftshift(shifted_freq)
Array([ 0. , 0.2, 0.4, -0.4, -0.2], dtype=float32)
"""
check_arraylike("fftshift", x)
x = jnp.asarray(x)
x = ensure_arraylike("fftshift", x)
shift: int | Sequence[int]
if axes is None:
axes = tuple(range(x.ndim))
@ -1337,8 +1335,7 @@ def ifftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array:
>>> jnp.fft.ifftshift(shifted_freq)
Array([ 0. , 0.2, 0.4, -0.4, -0.2], dtype=float32)
"""
check_arraylike("ifftshift", x)
x = jnp.asarray(x)
x = ensure_arraylike("ifftshift", x)
shift: int | Sequence[int]
if axes is None:
axes = tuple(range(x.ndim))

View File

@ -23,15 +23,16 @@ from jax import jit
from jax import lax
from jax._src import dtypes
from jax._src import core
from jax._src.lax import lax as lax_internal
from jax._src.numpy.lax_numpy import (
arange, argmin, array, asarray, atleast_1d, concatenate, convolve,
arange, argmin, array, atleast_1d, concatenate, convolve,
diag, dot, finfo, full, ones, outer, roll, trim_zeros,
trim_zeros_tol, vander, zeros)
from jax._src.numpy.ufuncs import maximum, true_divide, sqrt
from jax._src.numpy.reductions import all
from jax._src.numpy import linalg
from jax._src.numpy.util import (
check_arraylike, promote_dtypes, promote_dtypes_inexact, _where)
ensure_arraylike, promote_dtypes, promote_dtypes_inexact, _where)
from jax._src.typing import Array, ArrayLike
from jax._src.util import set_module
@ -102,8 +103,8 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array:
>>> jnp.roots(coeffs, strip_zeros=False)
Array([-2. +0.j, nan+nanj], dtype=complex64)
"""
check_arraylike("roots", p)
p_arr = atleast_1d(promote_dtypes_inexact(p)[0])
p = ensure_arraylike("roots", p)
p_arr = atleast_1d(*promote_dtypes_inexact(p))
del p
if p_arr.ndim != 1:
raise ValueError("Input must be a rank-1 array.")
@ -225,14 +226,13 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None,
((3, 3), (3, 3, 1))
"""
if w is None:
check_arraylike("polyfit", x, y)
x_arr, y_arr = ensure_arraylike("polyfit", x, y)
else:
check_arraylike("polyfit", x, y, w)
x_arr, y_arr, w = ensure_arraylike("polyfit", x, y, w)
del x, y
deg = core.concrete_or_error(int, deg, "deg must be int")
order = deg + 1
# check arguments
x_arr, y_arr = asarray(x), asarray(y)
del x, y
if deg < 0:
raise ValueError("expected deg >= 0")
if x_arr.ndim != 1:
@ -254,8 +254,7 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None,
# apply weighting
if w is not None:
w, = promote_dtypes_inexact(w)
w_arr = asarray(w)
w_arr, = promote_dtypes_inexact(w)
if w_arr.ndim != 1:
raise TypeError("expected a 1-d array for weights")
if w_arr.shape[0] != y_arr.shape[0]:
@ -273,7 +272,8 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None,
c = (c.T/scale).T # broadcast scale coefficients
if full:
return c, resids, rank, s, asarray(rcond)
assert rcond is not None
return c, resids, rank, s, lax_internal.asarray(rcond)
elif cov:
Vbase = linalg.inv(dot(lhs.T, lhs))
Vbase /= outer(scale, scale)
@ -351,7 +351,7 @@ def poly(seq_of_zeros: ArrayLike) -> Array:
>>> jnp.round(jnp.poly(x))
Array([ 1.+0.j, -11.-0.j, 9.+0.j, -15.+0.j], dtype=complex64)
"""
check_arraylike('poly', seq_of_zeros)
seq_of_zeros = ensure_arraylike('poly', seq_of_zeros)
seq_of_zeros, = promote_dtypes_inexact(seq_of_zeros)
seq_of_zeros_arr = atleast_1d(seq_of_zeros)
del seq_of_zeros
@ -431,7 +431,7 @@ def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array:
[ 34., 53., 134.],
[ 8., 34., 76.]], dtype=float32)
"""
check_arraylike("polyval", p, x)
p, x = ensure_arraylike("polyval", p, x)
p_arr, x_arr = promote_dtypes_inexact(p, x)
del p, x
shape = lax.broadcast_shapes(p_arr.shape[1:], x_arr.shape)
@ -489,7 +489,7 @@ def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array:
Array([[ 5, 7, 9],
[10, 8, 6]], dtype=int32)
"""
check_arraylike("polyadd", a1, a2)
a1, a2 = ensure_arraylike("polyadd", a1, a2)
a1_arr, a2_arr = promote_dtypes(a1, a2)
del a1, a2
if a2_arr.shape[0] <= a1_arr.shape[0]:
@ -548,7 +548,7 @@ def polyint(p: ArrayLike, m: int = 1, k: int | ArrayLike | None = None) -> Array
"""
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
k = 0 if k is None else k
check_arraylike("polyint", p, k)
p, k = ensure_arraylike("polyint", p, k)
p_arr, k_arr = promote_dtypes_inexact(p, k)
del p, k
if m < 0:
@ -605,7 +605,7 @@ def polyder(p: ArrayLike, m: int = 1) -> Array:
>>> jnp.polyder(p, m=2)
Array([ 12., -10.], dtype=float32)
"""
check_arraylike("polyder", p)
p = ensure_arraylike("polyder", p)
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder")
p_arr, = promote_dtypes_inexact(p)
del p
@ -673,15 +673,15 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -
>>> jnp.polymul(x3, x4, trim_leading_zeros=True)
Array([10. +0.j, 5.+10.j, 11.-10.j, 3. +6.j, 3. -6.j], dtype=complex64)
"""
check_arraylike("polymul", a1, a2)
a1, a2 = ensure_arraylike("polymul", a1, a2)
a1_arr, a2_arr = promote_dtypes_inexact(a1, a2)
del a1, a2
if trim_leading_zeros and (len(a1_arr) > 1 or len(a2_arr) > 1):
a1_arr, a2_arr = trim_zeros(a1_arr, trim='f'), trim_zeros(a2_arr, trim='f')
if len(a1_arr) == 0:
a1_arr = asarray([0], dtype=a2_arr.dtype)
a1_arr = zeros(1, dtype=a2_arr.dtype)
if len(a2_arr) == 0:
a2_arr = asarray([0], dtype=a1_arr.dtype)
a2_arr = zeros(1, dtype=a1_arr.dtype)
return convolve(a1_arr, a2_arr, mode='full')
@ -728,7 +728,7 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) ->
>>> jnp.polydiv(x1, x2, trim_leading_zeros=True)
(Array([1.25 , 1.4375], dtype=float32), Array([7.5625], dtype=float32))
"""
check_arraylike("polydiv", u, v)
u, v = ensure_arraylike("polydiv", u, v)
u_arr, v_arr = promote_dtypes_inexact(u, v)
del u, v
m = len(u_arr) - 1
@ -794,6 +794,6 @@ def polysub(a1: ArrayLike, a2: ArrayLike) -> Array:
Array([[5, 7, 9],
[6, 4, 2]], dtype=int32)
"""
check_arraylike("polysub", a1, a2)
a1, a2 = ensure_arraylike("polysub", a1, a2)
a1, a2 = promote_dtypes(a1, a2)
return polyadd(a1, -a2)