mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
internal: move more NumPy APIs to ensure_arraylike
This commit is contained in:
parent
6b95ad0a53
commit
23c1d62910
@ -22,7 +22,7 @@ from jax import dtypes
|
||||
from jax import lax
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.util import safe_zip
|
||||
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact
|
||||
from jax._src.numpy.util import ensure_arraylike, promote_dtypes_inexact
|
||||
from jax._src.numpy import lax_numpy as jnp
|
||||
from jax._src.numpy import ufuncs, reductions
|
||||
from jax._src.sharding import Sharding
|
||||
@ -49,8 +49,7 @@ def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike,
|
||||
s: Shape | None, axes: Sequence[int] | None,
|
||||
norm: str | None) -> Array:
|
||||
full_name = f"jax.numpy.fft.{func_name}"
|
||||
check_arraylike(full_name, a)
|
||||
arr = jnp.asarray(a)
|
||||
arr = ensure_arraylike(full_name, a)
|
||||
|
||||
if s is not None:
|
||||
s = tuple(map(operator.index, s))
|
||||
@ -1287,8 +1286,7 @@ def fftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array:
|
||||
>>> jnp.fft.ifftshift(shifted_freq)
|
||||
Array([ 0. , 0.2, 0.4, -0.4, -0.2], dtype=float32)
|
||||
"""
|
||||
check_arraylike("fftshift", x)
|
||||
x = jnp.asarray(x)
|
||||
x = ensure_arraylike("fftshift", x)
|
||||
shift: int | Sequence[int]
|
||||
if axes is None:
|
||||
axes = tuple(range(x.ndim))
|
||||
@ -1337,8 +1335,7 @@ def ifftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array:
|
||||
>>> jnp.fft.ifftshift(shifted_freq)
|
||||
Array([ 0. , 0.2, 0.4, -0.4, -0.2], dtype=float32)
|
||||
"""
|
||||
check_arraylike("ifftshift", x)
|
||||
x = jnp.asarray(x)
|
||||
x = ensure_arraylike("ifftshift", x)
|
||||
shift: int | Sequence[int]
|
||||
if axes is None:
|
||||
axes = tuple(range(x.ndim))
|
||||
|
@ -23,15 +23,16 @@ from jax import jit
|
||||
from jax import lax
|
||||
from jax._src import dtypes
|
||||
from jax._src import core
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.numpy.lax_numpy import (
|
||||
arange, argmin, array, asarray, atleast_1d, concatenate, convolve,
|
||||
arange, argmin, array, atleast_1d, concatenate, convolve,
|
||||
diag, dot, finfo, full, ones, outer, roll, trim_zeros,
|
||||
trim_zeros_tol, vander, zeros)
|
||||
from jax._src.numpy.ufuncs import maximum, true_divide, sqrt
|
||||
from jax._src.numpy.reductions import all
|
||||
from jax._src.numpy import linalg
|
||||
from jax._src.numpy.util import (
|
||||
check_arraylike, promote_dtypes, promote_dtypes_inexact, _where)
|
||||
ensure_arraylike, promote_dtypes, promote_dtypes_inexact, _where)
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax._src.util import set_module
|
||||
|
||||
@ -102,8 +103,8 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array:
|
||||
>>> jnp.roots(coeffs, strip_zeros=False)
|
||||
Array([-2. +0.j, nan+nanj], dtype=complex64)
|
||||
"""
|
||||
check_arraylike("roots", p)
|
||||
p_arr = atleast_1d(promote_dtypes_inexact(p)[0])
|
||||
p = ensure_arraylike("roots", p)
|
||||
p_arr = atleast_1d(*promote_dtypes_inexact(p))
|
||||
del p
|
||||
if p_arr.ndim != 1:
|
||||
raise ValueError("Input must be a rank-1 array.")
|
||||
@ -225,14 +226,13 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None,
|
||||
((3, 3), (3, 3, 1))
|
||||
"""
|
||||
if w is None:
|
||||
check_arraylike("polyfit", x, y)
|
||||
x_arr, y_arr = ensure_arraylike("polyfit", x, y)
|
||||
else:
|
||||
check_arraylike("polyfit", x, y, w)
|
||||
x_arr, y_arr, w = ensure_arraylike("polyfit", x, y, w)
|
||||
del x, y
|
||||
deg = core.concrete_or_error(int, deg, "deg must be int")
|
||||
order = deg + 1
|
||||
# check arguments
|
||||
x_arr, y_arr = asarray(x), asarray(y)
|
||||
del x, y
|
||||
if deg < 0:
|
||||
raise ValueError("expected deg >= 0")
|
||||
if x_arr.ndim != 1:
|
||||
@ -254,8 +254,7 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None,
|
||||
|
||||
# apply weighting
|
||||
if w is not None:
|
||||
w, = promote_dtypes_inexact(w)
|
||||
w_arr = asarray(w)
|
||||
w_arr, = promote_dtypes_inexact(w)
|
||||
if w_arr.ndim != 1:
|
||||
raise TypeError("expected a 1-d array for weights")
|
||||
if w_arr.shape[0] != y_arr.shape[0]:
|
||||
@ -273,7 +272,8 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None,
|
||||
c = (c.T/scale).T # broadcast scale coefficients
|
||||
|
||||
if full:
|
||||
return c, resids, rank, s, asarray(rcond)
|
||||
assert rcond is not None
|
||||
return c, resids, rank, s, lax_internal.asarray(rcond)
|
||||
elif cov:
|
||||
Vbase = linalg.inv(dot(lhs.T, lhs))
|
||||
Vbase /= outer(scale, scale)
|
||||
@ -351,7 +351,7 @@ def poly(seq_of_zeros: ArrayLike) -> Array:
|
||||
>>> jnp.round(jnp.poly(x))
|
||||
Array([ 1.+0.j, -11.-0.j, 9.+0.j, -15.+0.j], dtype=complex64)
|
||||
"""
|
||||
check_arraylike('poly', seq_of_zeros)
|
||||
seq_of_zeros = ensure_arraylike('poly', seq_of_zeros)
|
||||
seq_of_zeros, = promote_dtypes_inexact(seq_of_zeros)
|
||||
seq_of_zeros_arr = atleast_1d(seq_of_zeros)
|
||||
del seq_of_zeros
|
||||
@ -431,7 +431,7 @@ def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array:
|
||||
[ 34., 53., 134.],
|
||||
[ 8., 34., 76.]], dtype=float32)
|
||||
"""
|
||||
check_arraylike("polyval", p, x)
|
||||
p, x = ensure_arraylike("polyval", p, x)
|
||||
p_arr, x_arr = promote_dtypes_inexact(p, x)
|
||||
del p, x
|
||||
shape = lax.broadcast_shapes(p_arr.shape[1:], x_arr.shape)
|
||||
@ -489,7 +489,7 @@ def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array:
|
||||
Array([[ 5, 7, 9],
|
||||
[10, 8, 6]], dtype=int32)
|
||||
"""
|
||||
check_arraylike("polyadd", a1, a2)
|
||||
a1, a2 = ensure_arraylike("polyadd", a1, a2)
|
||||
a1_arr, a2_arr = promote_dtypes(a1, a2)
|
||||
del a1, a2
|
||||
if a2_arr.shape[0] <= a1_arr.shape[0]:
|
||||
@ -548,7 +548,7 @@ def polyint(p: ArrayLike, m: int = 1, k: int | ArrayLike | None = None) -> Array
|
||||
"""
|
||||
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
|
||||
k = 0 if k is None else k
|
||||
check_arraylike("polyint", p, k)
|
||||
p, k = ensure_arraylike("polyint", p, k)
|
||||
p_arr, k_arr = promote_dtypes_inexact(p, k)
|
||||
del p, k
|
||||
if m < 0:
|
||||
@ -605,7 +605,7 @@ def polyder(p: ArrayLike, m: int = 1) -> Array:
|
||||
>>> jnp.polyder(p, m=2)
|
||||
Array([ 12., -10.], dtype=float32)
|
||||
"""
|
||||
check_arraylike("polyder", p)
|
||||
p = ensure_arraylike("polyder", p)
|
||||
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder")
|
||||
p_arr, = promote_dtypes_inexact(p)
|
||||
del p
|
||||
@ -673,15 +673,15 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -
|
||||
>>> jnp.polymul(x3, x4, trim_leading_zeros=True)
|
||||
Array([10. +0.j, 5.+10.j, 11.-10.j, 3. +6.j, 3. -6.j], dtype=complex64)
|
||||
"""
|
||||
check_arraylike("polymul", a1, a2)
|
||||
a1, a2 = ensure_arraylike("polymul", a1, a2)
|
||||
a1_arr, a2_arr = promote_dtypes_inexact(a1, a2)
|
||||
del a1, a2
|
||||
if trim_leading_zeros and (len(a1_arr) > 1 or len(a2_arr) > 1):
|
||||
a1_arr, a2_arr = trim_zeros(a1_arr, trim='f'), trim_zeros(a2_arr, trim='f')
|
||||
if len(a1_arr) == 0:
|
||||
a1_arr = asarray([0], dtype=a2_arr.dtype)
|
||||
a1_arr = zeros(1, dtype=a2_arr.dtype)
|
||||
if len(a2_arr) == 0:
|
||||
a2_arr = asarray([0], dtype=a1_arr.dtype)
|
||||
a2_arr = zeros(1, dtype=a1_arr.dtype)
|
||||
return convolve(a1_arr, a2_arr, mode='full')
|
||||
|
||||
|
||||
@ -728,7 +728,7 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) ->
|
||||
>>> jnp.polydiv(x1, x2, trim_leading_zeros=True)
|
||||
(Array([1.25 , 1.4375], dtype=float32), Array([7.5625], dtype=float32))
|
||||
"""
|
||||
check_arraylike("polydiv", u, v)
|
||||
u, v = ensure_arraylike("polydiv", u, v)
|
||||
u_arr, v_arr = promote_dtypes_inexact(u, v)
|
||||
del u, v
|
||||
m = len(u_arr) - 1
|
||||
@ -794,6 +794,6 @@ def polysub(a1: ArrayLike, a2: ArrayLike) -> Array:
|
||||
Array([[5, 7, 9],
|
||||
[6, 4, 2]], dtype=int32)
|
||||
"""
|
||||
check_arraylike("polysub", a1, a2)
|
||||
a1, a2 = ensure_arraylike("polysub", a1, a2)
|
||||
a1, a2 = promote_dtypes(a1, a2)
|
||||
return polyadd(a1, -a2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user