From 23c1d6291030da978e97a51c368c7eb5b786710a Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 23 Jan 2025 08:48:13 -0800 Subject: [PATCH] internal: move more NumPy APIs to ensure_arraylike --- jax/_src/numpy/fft.py | 11 ++++------ jax/_src/numpy/polynomial.py | 42 ++++++++++++++++++------------------ 2 files changed, 25 insertions(+), 28 deletions(-) diff --git a/jax/_src/numpy/fft.py b/jax/_src/numpy/fft.py index c0707ea1c..fee54347f 100644 --- a/jax/_src/numpy/fft.py +++ b/jax/_src/numpy/fft.py @@ -22,7 +22,7 @@ from jax import dtypes from jax import lax from jax._src.lib import xla_client from jax._src.util import safe_zip -from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact +from jax._src.numpy.util import ensure_arraylike, promote_dtypes_inexact from jax._src.numpy import lax_numpy as jnp from jax._src.numpy import ufuncs, reductions from jax._src.sharding import Sharding @@ -49,8 +49,7 @@ def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike, s: Shape | None, axes: Sequence[int] | None, norm: str | None) -> Array: full_name = f"jax.numpy.fft.{func_name}" - check_arraylike(full_name, a) - arr = jnp.asarray(a) + arr = ensure_arraylike(full_name, a) if s is not None: s = tuple(map(operator.index, s)) @@ -1287,8 +1286,7 @@ def fftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array: >>> jnp.fft.ifftshift(shifted_freq) Array([ 0. , 0.2, 0.4, -0.4, -0.2], dtype=float32) """ - check_arraylike("fftshift", x) - x = jnp.asarray(x) + x = ensure_arraylike("fftshift", x) shift: int | Sequence[int] if axes is None: axes = tuple(range(x.ndim)) @@ -1337,8 +1335,7 @@ def ifftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array: >>> jnp.fft.ifftshift(shifted_freq) Array([ 0. , 0.2, 0.4, -0.4, -0.2], dtype=float32) """ - check_arraylike("ifftshift", x) - x = jnp.asarray(x) + x = ensure_arraylike("ifftshift", x) shift: int | Sequence[int] if axes is None: axes = tuple(range(x.ndim)) diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index 76616aa92..bba549054 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -23,15 +23,16 @@ from jax import jit from jax import lax from jax._src import dtypes from jax._src import core +from jax._src.lax import lax as lax_internal from jax._src.numpy.lax_numpy import ( - arange, argmin, array, asarray, atleast_1d, concatenate, convolve, + arange, argmin, array, atleast_1d, concatenate, convolve, diag, dot, finfo, full, ones, outer, roll, trim_zeros, trim_zeros_tol, vander, zeros) from jax._src.numpy.ufuncs import maximum, true_divide, sqrt from jax._src.numpy.reductions import all from jax._src.numpy import linalg from jax._src.numpy.util import ( - check_arraylike, promote_dtypes, promote_dtypes_inexact, _where) + ensure_arraylike, promote_dtypes, promote_dtypes_inexact, _where) from jax._src.typing import Array, ArrayLike from jax._src.util import set_module @@ -102,8 +103,8 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array: >>> jnp.roots(coeffs, strip_zeros=False) Array([-2. +0.j, nan+nanj], dtype=complex64) """ - check_arraylike("roots", p) - p_arr = atleast_1d(promote_dtypes_inexact(p)[0]) + p = ensure_arraylike("roots", p) + p_arr = atleast_1d(*promote_dtypes_inexact(p)) del p if p_arr.ndim != 1: raise ValueError("Input must be a rank-1 array.") @@ -225,14 +226,13 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, ((3, 3), (3, 3, 1)) """ if w is None: - check_arraylike("polyfit", x, y) + x_arr, y_arr = ensure_arraylike("polyfit", x, y) else: - check_arraylike("polyfit", x, y, w) + x_arr, y_arr, w = ensure_arraylike("polyfit", x, y, w) + del x, y deg = core.concrete_or_error(int, deg, "deg must be int") order = deg + 1 # check arguments - x_arr, y_arr = asarray(x), asarray(y) - del x, y if deg < 0: raise ValueError("expected deg >= 0") if x_arr.ndim != 1: @@ -254,8 +254,7 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, # apply weighting if w is not None: - w, = promote_dtypes_inexact(w) - w_arr = asarray(w) + w_arr, = promote_dtypes_inexact(w) if w_arr.ndim != 1: raise TypeError("expected a 1-d array for weights") if w_arr.shape[0] != y_arr.shape[0]: @@ -273,7 +272,8 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, c = (c.T/scale).T # broadcast scale coefficients if full: - return c, resids, rank, s, asarray(rcond) + assert rcond is not None + return c, resids, rank, s, lax_internal.asarray(rcond) elif cov: Vbase = linalg.inv(dot(lhs.T, lhs)) Vbase /= outer(scale, scale) @@ -351,7 +351,7 @@ def poly(seq_of_zeros: ArrayLike) -> Array: >>> jnp.round(jnp.poly(x)) Array([ 1.+0.j, -11.-0.j, 9.+0.j, -15.+0.j], dtype=complex64) """ - check_arraylike('poly', seq_of_zeros) + seq_of_zeros = ensure_arraylike('poly', seq_of_zeros) seq_of_zeros, = promote_dtypes_inexact(seq_of_zeros) seq_of_zeros_arr = atleast_1d(seq_of_zeros) del seq_of_zeros @@ -431,7 +431,7 @@ def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array: [ 34., 53., 134.], [ 8., 34., 76.]], dtype=float32) """ - check_arraylike("polyval", p, x) + p, x = ensure_arraylike("polyval", p, x) p_arr, x_arr = promote_dtypes_inexact(p, x) del p, x shape = lax.broadcast_shapes(p_arr.shape[1:], x_arr.shape) @@ -489,7 +489,7 @@ def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array: Array([[ 5, 7, 9], [10, 8, 6]], dtype=int32) """ - check_arraylike("polyadd", a1, a2) + a1, a2 = ensure_arraylike("polyadd", a1, a2) a1_arr, a2_arr = promote_dtypes(a1, a2) del a1, a2 if a2_arr.shape[0] <= a1_arr.shape[0]: @@ -548,7 +548,7 @@ def polyint(p: ArrayLike, m: int = 1, k: int | ArrayLike | None = None) -> Array """ m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint") k = 0 if k is None else k - check_arraylike("polyint", p, k) + p, k = ensure_arraylike("polyint", p, k) p_arr, k_arr = promote_dtypes_inexact(p, k) del p, k if m < 0: @@ -605,7 +605,7 @@ def polyder(p: ArrayLike, m: int = 1) -> Array: >>> jnp.polyder(p, m=2) Array([ 12., -10.], dtype=float32) """ - check_arraylike("polyder", p) + p = ensure_arraylike("polyder", p) m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder") p_arr, = promote_dtypes_inexact(p) del p @@ -673,15 +673,15 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) - >>> jnp.polymul(x3, x4, trim_leading_zeros=True) Array([10. +0.j, 5.+10.j, 11.-10.j, 3. +6.j, 3. -6.j], dtype=complex64) """ - check_arraylike("polymul", a1, a2) + a1, a2 = ensure_arraylike("polymul", a1, a2) a1_arr, a2_arr = promote_dtypes_inexact(a1, a2) del a1, a2 if trim_leading_zeros and (len(a1_arr) > 1 or len(a2_arr) > 1): a1_arr, a2_arr = trim_zeros(a1_arr, trim='f'), trim_zeros(a2_arr, trim='f') if len(a1_arr) == 0: - a1_arr = asarray([0], dtype=a2_arr.dtype) + a1_arr = zeros(1, dtype=a2_arr.dtype) if len(a2_arr) == 0: - a2_arr = asarray([0], dtype=a1_arr.dtype) + a2_arr = zeros(1, dtype=a1_arr.dtype) return convolve(a1_arr, a2_arr, mode='full') @@ -728,7 +728,7 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> >>> jnp.polydiv(x1, x2, trim_leading_zeros=True) (Array([1.25 , 1.4375], dtype=float32), Array([7.5625], dtype=float32)) """ - check_arraylike("polydiv", u, v) + u, v = ensure_arraylike("polydiv", u, v) u_arr, v_arr = promote_dtypes_inexact(u, v) del u, v m = len(u_arr) - 1 @@ -794,6 +794,6 @@ def polysub(a1: ArrayLike, a2: ArrayLike) -> Array: Array([[5, 7, 9], [6, 4, 2]], dtype=int32) """ - check_arraylike("polysub", a1, a2) + a1, a2 = ensure_arraylike("polysub", a1, a2) a1, a2 = promote_dtypes(a1, a2) return polyadd(a1, -a2)