diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 8f078ceee..7f3e05f95 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -710,9 +710,9 @@ def trunc(x: ArrayLike) -> Array: [ 1., -0., 1.], [-8., 5., 3.]], dtype=float32) """ - util.check_arraylike('trunc', x) + x = util.ensure_arraylike('trunc', x) if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): - return lax_internal.asarray(x) + return x return where(lax.lt(x, _lax_const(x, 0)), ufuncs.ceil(x), ufuncs.floor(x)) @@ -827,8 +827,8 @@ def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *, >>> jnp.convolve(x1, y1) Array([ 3. +1.j, 11. -7.j, 15.+10.j, 7. -8.j, 31. +8.j], dtype=complex64) """ - util.check_arraylike("convolve", a, v) - return _conv(asarray(a), asarray(v), mode=mode, op='convolve', + a, v = util.ensure_arraylike("convolve", a, v) + return _conv(a, v, mode=mode, op='convolve', precision=precision, preferred_element_type=preferred_element_type) @@ -913,8 +913,8 @@ def correlate(a: ArrayLike, v: ArrayLike, mode: str = 'valid', *, >>> jnp.correlate(x2, y2, mode='full') Array([ 3. +1.j, 3.+17.j, 18.+11.j, 27. +4.j, 8.-12.j], dtype=complex64) """ - util.check_arraylike("correlate", a, v) - return _conv(asarray(a), asarray(v), mode=mode, op='correlate', + a, v = util.ensure_arraylike("correlate", a, v) + return _conv(a, v, mode=mode, op='correlate', precision=precision, preferred_element_type=preferred_element_type) @@ -1556,8 +1556,8 @@ def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array: [[8, 7], [6, 5]]], dtype=int32) """ - util.check_arraylike("flip", m) - return _flip(asarray(m), reductions._ensure_optional_axes(axis)) + arr = util.ensure_arraylike("flip", m) + return _flip(arr, reductions._ensure_optional_axes(axis)) @partial(jit, static_argnames=('axis',)) def _flip(m: Array, axis: int | tuple[int, ...] 
| None = None) -> Array: @@ -1590,8 +1590,8 @@ def fliplr(m: ArrayLike) -> Array: Array([[2, 1], [4, 3]], dtype=int32) """ - util.check_arraylike("fliplr", m) - return _flip(asarray(m), 1) + arr = util.ensure_arraylike("fliplr", m) + return _flip(arr, 1) @export @@ -1617,8 +1617,8 @@ def flipud(m: ArrayLike) -> Array: Array([[3, 4], [1, 2]], dtype=int32) """ - util.check_arraylike("flipud", m) - return _flip(asarray(m), 0) + arr = util.ensure_arraylike("flipud", m) + return _flip(arr, 0) @export @@ -1786,8 +1786,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1, Array([[ 4, -3, 7, -6], [ 5, -1, -3, -3]], dtype=int32) """ - util.check_arraylike("diff", a) - arr = asarray(a) + arr = util.ensure_arraylike("diff", a) n = core.concrete_or_error(operator.index, n, "'n' argument of jnp.diff") axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.diff") if n == 0: @@ -1802,22 +1801,22 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1, combined: list[Array] = [] if prepend is not None: - util.check_arraylike("diff", prepend) + prepend = util.ensure_arraylike("diff", prepend) if not ndim(prepend): shape = list(arr.shape) shape[axis] = 1 prepend = broadcast_to(prepend, tuple(shape)) - combined.append(asarray(prepend)) + combined.append(prepend) combined.append(arr) if append is not None: - util.check_arraylike("diff", append) + append = util.ensure_arraylike("diff", append) if not ndim(append): shape = list(arr.shape) shape[axis] = 1 append = broadcast_to(append, tuple(shape)) - combined.append(asarray(append)) + combined.append(append) if len(combined) > 1: arr = concatenate(combined, axis) @@ -1888,15 +1887,14 @@ def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = None, >>> jnp.ediff1d(a2) Array([ -3, 5, 3, -4, 2, -11, 15], dtype=int32) """ - util.check_arraylike("ediff1d", ary) - arr = ravel(ary) + arr = util.ensure_arraylike("ediff1d", ary).ravel() result = lax.sub(arr[1:], arr[:-1]) if to_begin is not None: - 
util.check_arraylike("ediff1d", to_begin) - result = concatenate((ravel(asarray(to_begin, dtype=arr.dtype)), result)) + to_begin = util.ensure_arraylike("ediff1d", to_begin) + result = concatenate((ravel(to_begin.astype(arr.dtype)), result)) if to_end is not None: - util.check_arraylike("ediff1d", to_end) - result = concatenate((result, ravel(asarray(to_end, dtype=arr.dtype)))) + to_end = util.ensure_arraylike("ediff1d", to_end) + result = concatenate((result, ravel(to_end.astype(arr.dtype)))) return result @@ -2350,8 +2348,7 @@ def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: >>> jnp.ravel_multi_index(indices_2D, shape) Array([1, 3, 5], dtype=int32) """ - util.check_arraylike("unravel_index", indices) - indices_arr = asarray(indices) + indices_arr = util.ensure_arraylike("unravel_index", indices) # Note: we do not convert shape to an array, because it may be passed as a # tuple of weakly-typed values, and asarray() would strip these weak types. try: @@ -2480,8 +2477,8 @@ def squeeze(a: ArrayLike, axis: int | Sequence[int] | None = None) -> Array: >>> x.squeeze() Array([0, 1, 2], dtype=int32) """ - util.check_arraylike("squeeze", a) - return _squeeze(asarray(a), _ensure_index_tuple(axis) if axis is not None else None) + arr = util.ensure_arraylike("squeeze", a) + return _squeeze(arr, _ensure_index_tuple(axis) if axis is not None else None) @partial(jit, static_argnames=('axis',), inline=True) def _squeeze(a: Array, axis: tuple[int, ...]) -> Array: @@ -2662,8 +2659,8 @@ def moveaxis(a: ArrayLike, source: int | Sequence[int], >>> a.transpose(2, 3, 1, 0).shape (4, 5, 3, 2) """ - util.check_arraylike("moveaxis", a) - return _moveaxis(asarray(a), _ensure_index_tuple(source), + arr = util.ensure_arraylike("moveaxis", a) + return _moveaxis(arr, _ensure_index_tuple(source), _ensure_index_tuple(destination)) @partial(jit, static_argnames=('source', 'destination'), inline=True) @@ -3266,8 +3263,7 @@ def broadcast_to(array: ArrayLike, shape: DimSize | 
Shape) -> Array: def _split(op: str, ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike, axis: int = 0) -> list[Array]: - util.check_arraylike(op, ary) - ary = asarray(ary) + ary = util.ensure_arraylike(op, ary) axis = core.concrete_or_error(operator.index, axis, f"in jax.numpy.{op} argument `axis`") size = ary.shape[axis] if (isinstance(indices_or_sections, (tuple, list)) or @@ -3430,8 +3426,7 @@ def hsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) - :func:`jax.numpy.array_split`: like ``split``, but allows ``indices_or_sections`` to be an integer that does not evenly divide the size of the array. """ - util.check_arraylike("hsplit", ary) - a = asarray(ary) + a = util.ensure_arraylike("hsplit", ary) return _split("hsplit", a, indices_or_sections, axis=0 if a.ndim == 1 else 1) @@ -3616,7 +3611,7 @@ def round(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: >>> jnp.round(x1) Array([10., 22., 12., 32.], dtype=float32) """ - util.check_arraylike("round", a) + a = util.ensure_arraylike("round", a) decimals = core.concrete_or_error(operator.index, decimals, "'decimals' argument of jnp.round") if out is not None: raise NotImplementedError("The 'out' argument to jnp.round is not supported.") @@ -3625,7 +3620,7 @@ def round(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: if decimals < 0: raise NotImplementedError( "integer np.round not implemented for decimals < 0") - return asarray(a) # no-op on integer types + return a # no-op on integer types def _round_float(x: ArrayLike) -> Array: if decimals == 0: @@ -3742,10 +3737,10 @@ def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0, Array([ 0., 0., 1., inf, 2., -inf], dtype=float32) """ del copy - util.check_arraylike("nan_to_num", x) + x = util.ensure_arraylike("nan_to_num", x) dtype = _dtype(x) if not issubdtype(dtype, inexact): - return asarray(x) + return x if issubdtype(dtype, complexfloating): return lax.complex( 
nan_to_num(lax.real(x), nan=nan, posinf=posinf, neginf=neginf), @@ -3890,8 +3885,7 @@ def nonzero(a: ArrayLike, *, size: int | None = None, >>> nonzero_jit(x, size=5, fill_value=len(x)) (Array([1, 3, 5, 6, 6], dtype=int32),) """ - util.check_arraylike("nonzero", a) - arr = asarray(a) + arr = util.ensure_arraylike("nonzero", a) del a if ndim(arr) == 0: raise ValueError("Calling nonzero on 0d arrays is not allowed. " @@ -4020,8 +4014,7 @@ def unwrap(p: ArrayLike, discont: ArrayLike | None = None, a larger discontinuity it adds factors of the period to the data. For periodic signals that satisfy this assumption, :func:`unwrap` can recover the original phased signal. """ - util.check_arraylike("unwrap", p) - p = asarray(p) + p = util.ensure_arraylike("unwrap", p) if issubdtype(p.dtype, np.complexfloating): raise ValueError("jnp.unwrap does not support complex inputs.") if p.shape[axis] == 0: @@ -4648,8 +4641,7 @@ def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]: Array([[1, 2, 3], [4, 5, 6]], dtype=int32) """ - util.check_arraylike("unstack", x) - x = asarray(x) + x = util.ensure_arraylike("unstack", x) if x.ndim == 0: raise ValueError( "Unstack requires arrays with rank > 0, however a scalar array was " @@ -5712,8 +5704,7 @@ def astype(x: ArrayLike, dtype: DTypeLike | None, >>> y.astype(int) # truncates fractional values Array([0, 0, 1], dtype=int32) """ - util.check_arraylike("astype", x) - x_arr = asarray(x) + x_arr = util.ensure_arraylike("astype", x) if dtype is None: dtype = dtypes.canonicalize_dtype(float_) @@ -6642,8 +6633,7 @@ def _eye(N: DimSize, M: DimSize | None = None, if isinstance(k, int): k = lax_internal._clip_int_to_valid_range(k, np.int32, "`argument `k` of jax.numpy.eye") - util.check_arraylike("eye", k) - offset = asarray(k) + offset = util.ensure_arraylike("eye", k) if not (offset.shape == () and dtypes.issubdtype(offset.dtype, np.integer)): raise ValueError(f"k must be a scalar integer; got {k}") N_int = core.canonicalize_dim(N, 
"argument of 'N' jnp.eye()") @@ -6935,14 +6925,14 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, dtypes.check_user_dtype_supported(dtype, "linspace") if num < 0: raise ValueError(f"Number of samples, {num}, must be non-negative.") - util.check_arraylike("linspace", start, stop) + start, stop = util.ensure_arraylike("linspace", start, stop) if dtype is None: dtype = dtypes.to_inexact_dtype(result_type(start, stop)) dtype = _jnp_dtype(dtype) computation_dtype = dtypes.to_inexact_dtype(dtype) - start = asarray(start, dtype=computation_dtype) - stop = asarray(stop, dtype=computation_dtype) + start = start.astype(computation_dtype) + stop = stop.astype(computation_dtype) bounds_shape = list(lax.broadcast_shapes(shape(start), shape(stop))) broadcast_start = broadcast_to(start, bounds_shape) @@ -7061,9 +7051,9 @@ def _logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, dtype = dtypes.to_inexact_dtype(result_type(start, stop)) dtype = _jnp_dtype(dtype) computation_dtype = dtypes.to_inexact_dtype(dtype) - util.check_arraylike("logspace", start, stop) - start = asarray(start, dtype=computation_dtype) - stop = asarray(stop, dtype=computation_dtype) + start, stop = util.ensure_arraylike("logspace", start, stop) + start = start.astype(computation_dtype) + stop = stop.astype(computation_dtype) lin = linspace(start, stop, num, endpoint=endpoint, retstep=False, dtype=None, axis=axis) return lax.convert_element_type(ufuncs.power(base, lin), dtype) @@ -7131,9 +7121,9 @@ def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool dtype = dtypes.to_inexact_dtype(result_type(start, stop)) dtype = _jnp_dtype(dtype) computation_dtype = dtypes.to_inexact_dtype(dtype) - util.check_arraylike("geomspace", start, stop) - start = asarray(start, dtype=computation_dtype) - stop = asarray(stop, dtype=computation_dtype) + start, stop = util.ensure_arraylike("geomspace", start, stop) + start = start.astype(computation_dtype) + stop = 
stop.astype(computation_dtype) sign = ufuncs.sign(start) res = sign * logspace(ufuncs.log10(start / sign), ufuncs.log10(stop / sign), @@ -7207,8 +7197,7 @@ def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, [[10 20 30] [10 20 30]] """ - util.check_arraylike("meshgrid", *xi) - args = [asarray(x) for x in xi] + args = list(util.ensure_arraylike_tuple("meshgrid", tuple(xi))) if not copy: raise ValueError("jax.numpy.meshgrid only supports copy=True") if indexing not in ["xy", "ij"]: @@ -7310,11 +7299,10 @@ def ix_(*args: ArrayLike) -> tuple[Array, ...]: Array([[ 20, 40], [100, 120]], dtype=int32) """ - util.check_arraylike("ix", *args) + args = util.ensure_arraylike_tuple("ix", args) n = len(args) output = [] for i, a in enumerate(args): - a = asarray(a) if len(a.shape) != 1: msg = "Arguments to jax.numpy.ix_ must be 1-dimensional, got shape {}" raise ValueError(msg.format(a.shape)) @@ -7457,14 +7445,12 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, Array([[1, 1, 2, 2, 2, 2, 2], [3, 3, 4, 4, 4, 4, 4]], dtype=int32) """ - util.check_arraylike("repeat", a) + arr = util.ensure_arraylike("repeat", a) core.is_dim(repeats) or util.check_arraylike("repeat", repeats) if axis is None: - a = ravel(a) + arr = arr.ravel() axis = 0 - else: - a = asarray(a) axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.repeat()") assert isinstance(axis, int) # to appease mypy @@ -7482,44 +7468,44 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, "value to `total_repeat_length`.") # Fast path for when repeats is a scalar. 
- if np.ndim(repeats) == 0 and ndim(a) != 0: - input_shape = shape(a) + if np.ndim(repeats) == 0 and ndim(arr) != 0: + input_shape = arr.shape axis = _canonicalize_axis(axis, len(input_shape)) aux_axis = axis + 1 aux_shape: list[DimSize] = list(input_shape) aux_shape.insert(aux_axis, operator.index(repeats) if core.is_constant_dim(repeats) else repeats) # type: ignore - a = lax.broadcast_in_dim( - a, aux_shape, [i for i in range(len(aux_shape)) if i != aux_axis]) + arr = lax.broadcast_in_dim( + arr, aux_shape, [i for i in range(len(aux_shape)) if i != aux_axis]) result_shape: list[DimSize] = list(input_shape) result_shape[axis] *= repeats - return reshape(a, result_shape) + return arr.reshape(result_shape) repeats = np.ravel(repeats) - if ndim(a) != 0: - repeats = np.broadcast_to(repeats, [shape(a)[axis]]) + if arr.ndim != 0: + repeats = np.broadcast_to(repeats, [arr.shape[axis]]) total_repeat_length = np.sum(repeats) else: repeats = ravel(repeats) - if ndim(a) != 0: - repeats = broadcast_to(repeats, [shape(a)[axis]]) + if arr.ndim != 0: + repeats = broadcast_to(repeats, [arr.shape[axis]]) # Special case when a is a scalar. - if ndim(a) == 0: + if arr.ndim == 0: if shape(repeats) == (1,): - return full([total_repeat_length], a) + return full([total_repeat_length], arr) else: raise ValueError('`repeat` with a scalar parameter `a` is only ' 'implemented for scalar values of the parameter `repeats`.') # Special case if total_repeat_length is zero. if total_repeat_length == 0: - result_shape = list(shape(a)) + result_shape = list(arr.shape) result_shape[axis] = 0 - return reshape(array([], dtype=_dtype(a)), result_shape) + return reshape(array([], dtype=arr.dtype), result_shape) # If repeats is on a zero sized axis, then return the array. - if shape(a)[axis] == 0: - return asarray(a) + if arr.shape[axis] == 0: + return arr # This implementation of repeat avoid having to instantiate a large. # intermediate tensor. 
@@ -7533,7 +7519,7 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, block_split_indicators = block_split_indicators.at[scatter_indices].add(1) # Cumsum again to get scatter indices for repeat, e.g. [0,1,1,3,3,3,3,3] gather_indices = reductions.cumsum(block_split_indicators) - 1 - return take(a, gather_indices, axis=axis) + return take(arr, gather_indices, axis=axis) @export @@ -8213,9 +8199,7 @@ def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = False, *, raise NotImplementedError("JAX arrays are immutable, must use inplace=False") if wrap: raise NotImplementedError("wrap=True is not implemented, must use wrap=False") - util.check_arraylike("fill_diagonal", a, val) - a = asarray(a) - val = asarray(val) + a, val = util.ensure_arraylike("fill_diagonal", a, val) if a.ndim < 2: raise ValueError("array must be at least 2-d") if a.ndim > 2 and not all(n == a.shape[0] for n in a.shape[1:]): @@ -8685,11 +8669,10 @@ def delete( >>> jit_delete(a, indices, assume_unique_indices=True) Array([6, 8, 9], dtype=int32) """ - util.check_arraylike("delete", arr) + a = util.ensure_arraylike("delete", arr) if axis is None: - arr = ravel(arr) + a = a.ravel() axis = 0 - a = asarray(arr) axis = _canonicalize_axis(axis, a.ndim) # Case 1: obj is a static integer. 
@@ -8788,9 +8771,7 @@ def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike, Array([[ 1, 10, 2, 3, 11], [ 4, 12, 5, 6, 13]], dtype=int32) """ - util.check_arraylike("insert", arr, 0 if isinstance(obj, slice) else obj, values) - a = asarray(arr) - values_arr = asarray(values) + a, _, values_arr = util.ensure_arraylike("insert", arr, 0 if isinstance(obj, slice) else obj, values) if axis is None: a = ravel(a) @@ -8960,8 +8941,7 @@ def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike, >>> jnp.prod(x, [0, 1], keepdims=True) Array([[720]], dtype=int32) """ - util.check_arraylike("apply_over_axes", a) - a_arr = asarray(a) + a_arr = util.ensure_arraylike("apply_over_axes", a) for axis in axes: b = func(a_arr, axis) if b.ndim == a_arr.ndim: @@ -9041,9 +9021,8 @@ def dot(a: ArrayLike, b: ArrayLike, *, >>> jnp.matmul(a, b).shape (3, 2, 1) """ - util.check_arraylike("dot", a, b) + a, b = util.ensure_arraylike("dot", a, b) dtypes.check_user_dtype_supported(preferred_element_type, "dot") - a, b = asarray(a), asarray(b) if preferred_element_type is None: preferred_element_type, output_weak_type = dtypes.result_type(a, b, return_weak_type_flag=True) else: @@ -9124,9 +9103,8 @@ def matmul(a: ArrayLike, b: ArrayLike, *, Array([[22, 28], [49, 64]], dtype=int32) """ - util.check_arraylike("matmul", a, b) + a, b = util.ensure_arraylike("matmul", a, b) dtypes.check_user_dtype_supported(preferred_element_type, "matmul") - a, b = asarray(a), asarray(b) for i, x in enumerate((a, b)): if ndim(x) < 1: msg = (f"matmul input operand {i} must have ndim at least 1, " @@ -9368,8 +9346,7 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, >>> jnp.linalg.vecdot(a, b, axis=-1) Array([20, 47], dtype=int32) """ - util.check_arraylike("jnp.vecdot", x1, x2) - x1_arr, x2_arr = asarray(x1), asarray(x2) + x1_arr, x2_arr = util.ensure_arraylike("jnp.vecdot", x1, x2) if x1_arr.shape[axis] != x2_arr.shape[axis]: raise ValueError(f"axes must match; got shapes 
{x1_arr.shape} and {x2_arr.shape} with {axis=}") x1_arr = jax.numpy.moveaxis(x1_arr, axis, -1) @@ -9454,9 +9431,8 @@ def tensordot(a: ArrayLike, b: ArrayLike, Array([[1, 2, 3], [2, 4, 6]], dtype=int32) """ - util.check_arraylike("tensordot", a, b) + a, b = util.ensure_arraylike("tensordot", a, b) dtypes.check_user_dtype_supported(preferred_element_type, "tensordot") - a, b = asarray(a), asarray(b) a_ndim = ndim(a) b_ndim = ndim(b) @@ -10083,7 +10059,7 @@ def inner( >>> jnp.inner(a, b).shape (2, 5) """ - util.check_arraylike("inner", a, b) + a, b = util.ensure_arraylike("inner", a, b) if ndim(a) == 0 or ndim(b) == 0: a = asarray(a, dtype=preferred_element_type) b = asarray(b, dtype=preferred_element_type) @@ -10320,8 +10296,7 @@ def vander( [ 1, 3, 9, 27], [ 1, 4, 16, 64]], dtype=int32) """ - util.check_arraylike("vander", x) - x = asarray(x) + x = util.ensure_arraylike("vander", x) if x.ndim != 1: raise ValueError("x must be a one-dimensional array") N = x.shape[0] if N is None else core.concrete_or_error( @@ -10440,10 +10415,10 @@ def argmax(a: ArrayLike, axis: int | None = None, out: None = None, Array([[1], [0]], dtype=int32) """ - util.check_arraylike("argmax", a) + arr = util.ensure_arraylike("argmax", a) if out is not None: raise NotImplementedError("The 'out' argument to jnp.argmax is not supported.") - return _argmax(asarray(a), None if axis is None else operator.index(axis), + return _argmax(arr, None if axis is None else operator.index(axis), keepdims=bool(keepdims)) @partial(jit, static_argnames=('axis', 'keepdims'), inline=True) @@ -10496,10 +10471,10 @@ def argmin(a: ArrayLike, axis: int | None = None, out: None = None, Array([[0], [2]], dtype=int32) """ - util.check_arraylike("argmin", a) + arr = util.ensure_arraylike("argmin", a) if out is not None: raise NotImplementedError("The 'out' argument to jnp.argmin is not supported.") - return _argmin(asarray(a), None if axis is None else operator.index(axis), + return _argmin(arr, None if axis is None else 
operator.index(axis), keepdims=bool(keepdims)) @partial(jit, static_argnames=('axis', 'keepdims'), inline=True) @@ -10693,17 +10668,15 @@ def sort( - :func:`jax.numpy.lexsort`: lexicographical sort of multiple arrays. - :func:`jax.lax.sort`: lower-level function wrapping XLA's Sort operator. """ - util.check_arraylike("sort", a) + arr = util.ensure_arraylike("sort", a) if kind is not None: raise TypeError("'kind' argument to sort is not supported. Use" " stable=True or stable=False to specify sort stability.") if order is not None: raise TypeError("'order' argument to sort is not supported.") if axis is None: - arr = ravel(a) + arr = arr.ravel() axis = 0 - else: - arr = asarray(a) dimension = _canonicalize_axis(axis, arr.ndim) result = lax.sort(arr, dimension=dimension, is_stable=stable) return lax.rev(result, dimensions=[dimension]) if descending else result @@ -10742,8 +10715,8 @@ def sort_complex(a: ArrayLike) -> Array: Array([[3.+0.j, 4.+0.j, 5.+0.j], [2.+0.j, 6.+0.j, 9.+0.j]], dtype=complex64) """ - util.check_arraylike("sort_complex", a) - a = lax.sort(asarray(a)) + a = util.ensure_arraylike("sort_complex", a) + a = lax.sort(a) return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype)) @@ -10810,9 +10783,7 @@ def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> A Array([[0, 1, 0, 1], [1, 0, 1, 0]], dtype=int32) """ - key_tuple = tuple(keys) - util.check_arraylike("lexsort", *key_tuple) - key_arrays = tuple(asarray(k) for k in key_tuple) + key_arrays = util.ensure_arraylike_tuple("lexsort", tuple(keys)) if len(key_arrays) == 0: raise TypeError("need sequence of keys with len > 0 in lexsort") if len({shape(key) for key in key_arrays}) > 1: @@ -10881,18 +10852,15 @@ def argsort( - :func:`jax.numpy.lexsort`: lexicographical sort of multiple arrays. - :func:`jax.lax.sort`: lower-level function wrapping XLA's Sort operator. 
""" - util.check_arraylike("argsort", a) - arr = asarray(a) + arr = util.ensure_arraylike("argsort", a) if kind is not None: raise TypeError("'kind' argument to argsort is not supported. Use" " stable=True or stable=False to specify sort stability.") if order is not None: raise TypeError("'order' argument to argsort is not supported.") if axis is None: - arr = ravel(arr) + arr = arr.ravel() axis = 0 - else: - arr = asarray(a) dimension = _canonicalize_axis(axis, arr.ndim) use_64bit_index = not core.is_constant_dim(arr.shape[dimension]) or arr.shape[dimension] >= (1 << 31) iota = lax.broadcasted_iota(int64 if use_64bit_index else int_, arr.shape, dimension) @@ -10959,8 +10927,7 @@ def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array: order is arbitrary and implementation-dependent. """ # TODO(jakevdp): handle NaN values like numpy. - util.check_arraylike("partition", a) - arr = asarray(a) + arr = util.ensure_arraylike("partition", a) if issubdtype(arr.dtype, np.complexfloating): raise NotImplementedError("jnp.partition for complex dtype is not implemented.") axis = _canonicalize_axis(axis, arr.ndim) @@ -11031,8 +10998,7 @@ def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array: order is arbitrary and implementation-dependent. """ # TODO(jakevdp): handle NaN values like numpy. 
- util.check_arraylike("partition", a) - arr = asarray(a) + arr = util.ensure_arraylike("argpartition", a) if issubdtype(arr.dtype, np.complexfloating): raise NotImplementedError("jnp.argpartition for complex dtype is not implemented.") axis = _canonicalize_axis(axis, arr.ndim) @@ -11123,8 +11089,7 @@ def roll(a: ArrayLike, shift: ArrayLike | Sequence[int], [ 9, 10, 11, 8], [ 1, 2, 3, 0]], dtype=int32) """ - util.check_arraylike("roll", a) - arr = asarray(a) + arr = util.ensure_arraylike("roll", a) if axis is None: return roll(arr.ravel(), shift, 0).reshape(arr.shape) axis = _ensure_index_tuple(axis) @@ -11262,8 +11227,7 @@ def packbits(a: ArrayLike, axis: int | None = None, bitorder: str = "big") -> Ar Array([[1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0], [0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1]], dtype=uint8) """ - util.check_arraylike("packbits", a) - arr = asarray(a) + arr = util.ensure_arraylike("packbits", a) if not (issubdtype(arr.dtype, integer) or issubdtype(arr.dtype, bool_)): raise TypeError('Expected an input array of integer or boolean data type') if bitorder not in ['little', 'big']: @@ -11357,9 +11321,8 @@ def unpackbits( >>> jnp.unpackbits(vals, count=-5) # specify 5 bits to be trimmed Array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1], dtype=uint8) """ - util.check_arraylike("unpackbits", a) - arr = asarray(a) - if _dtype(a) != uint8: + arr = util.ensure_arraylike("unpackbits", a) + if arr.dtype != uint8: raise TypeError("Expected an input array of unsigned byte data type") if bitorder not in ['little', 'big']: raise ValueError("'order' must be either 'little' or 'big'") @@ -11473,9 +11436,7 @@ def _take(a, indices, axis: int | None = None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None): if out is not None: raise NotImplementedError("The 'out' argument to jnp.take is not supported.") - util.check_arraylike("take", a, indices) - a = asarray(a) - indices = asarray(indices) + a, indices = 
util.ensure_arraylike("take", a, indices) if axis is None: a = ravel(a) @@ -11618,8 +11579,7 @@ def take_along_axis( Array([[3], [2]], dtype=int32) """ - util.check_arraylike("take_along_axis", arr, indices) - a = asarray(arr) + a, indices = util.ensure_arraylike("take_along_axis", arr, indices) index_dtype = dtypes.dtype(indices) idx_shape = shape(indices) if not dtypes.issubdtype(index_dtype, integer): @@ -11791,10 +11751,7 @@ def put_along_axis( "jax.numpy.put_along_axis cannot modify arrays in-place, because JAX arrays" "are immutable. Pass inplace=False to instead return an updated array.") - util.check_arraylike("put_along_axis", arr, indices, values) - arr = asarray(arr) - indices = asarray(indices) - values = asarray(values) + arr, indices, values = util.ensure_arraylike("put_along_axis", arr, indices, values) original_axis = axis original_arr_shape = arr.shape @@ -12814,17 +12771,17 @@ def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = None, [ 6, 0, 0, 0], [ 9, 12, 0, 0]], dtype=int32) """ - util.check_arraylike("compress", condition, a, fill_value) - condition_arr = asarray(condition).astype(bool) + condition_arr, arr, fill_value = util.ensure_arraylike("compress", condition, a, fill_value) + condition_arr = condition_arr.astype(bool) if out is not None: raise NotImplementedError("The 'out' argument to jnp.compress is not supported.") if condition_arr.ndim != 1: raise ValueError("condition must be a 1D array") if axis is None: axis = 0 - arr = ravel(a) + arr = ravel(arr) else: - arr = moveaxis(a, axis, 0) + arr = moveaxis(arr, axis, 0) condition_arr, extra = condition_arr[:arr.shape[0]], condition_arr[arr.shape[0]:] arr = arr[:condition_arr.shape[0]] @@ -12965,7 +12922,7 @@ def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True, w: Array | None = None if fweights is not None: - util.check_arraylike("cov", fweights) + fweights = util.ensure_arraylike("cov", fweights) if ndim(fweights) > 1: raise RuntimeError("cannot handle 
multidimensional fweights") if shape(fweights)[0] != X.shape[1]: @@ -12973,16 +12930,16 @@ def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True, if not issubdtype(_dtype(fweights), integer): raise TypeError("fweights must be integer.") # Ensure positive fweights; note that numpy raises an error on negative fweights. - w = asarray(ufuncs.abs(fweights)) + w = abs(fweights) if aweights is not None: - util.check_arraylike("cov", aweights) + aweights = util.ensure_arraylike("cov", aweights) if ndim(aweights) > 1: raise RuntimeError("cannot handle multidimensional aweights") if shape(aweights)[0] != X.shape[1]: raise RuntimeError("incompatible numbers of samples and aweights") # Ensure positive aweights: note that numpy raises an error for negative aweights. - aweights = ufuncs.abs(aweights) - w = asarray(aweights) if w is None else w * asarray(aweights) + aweights = abs(aweights) + w = aweights if w is None else w * aweights avg, w_sum = reductions.average(X, axis=1, weights=w, returned=True) w_sum = w_sum[0] @@ -13218,7 +13175,7 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', 'sort': _searchsorted_via_sort, 'compare_all': _searchsorted_via_compare_all, }[method] - return impl(asarray(a), asarray(v), side, dtype) # type: ignore + return impl(a, v, side, dtype) # type: ignore @export @@ -13261,9 +13218,8 @@ def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False, >>> jnp.digitize(x, bins) Array([2, 1, 1, 2, 0, 0], dtype=int32) """ - util.check_arraylike("digitize", x, bins) + x, bins_arr = util.ensure_arraylike("digitize", x, bins) right = core.concrete_or_error(bool, right, "right argument of jnp.digitize()") - bins_arr = asarray(bins) if bins_arr.ndim != 1: raise ValueError(f"digitize: bins must be a 1-dimensional array; got {bins=}") if bins_arr.shape[0] == 0: @@ -13347,7 +13303,7 @@ def piecewise(x: ArrayLike, condlist: Array | Sequence[ArrayLike], >>> jnp.piecewise(x, condlist, funclist) Array([-40, -30, -20, -10, 0, 10, 
20, 30, 40], dtype=int32) """ - util.check_arraylike("piecewise", x) + x_arr = util.ensure_arraylike("piecewise", x) nc, nf = len(condlist), len(funclist) if nf == nc + 1: funclist = funclist[-1:] + funclist[:-1] @@ -13357,7 +13313,7 @@ def piecewise(x: ArrayLike, condlist: Array | Sequence[ArrayLike], raise ValueError(f"with {nc} condition(s), either {nc} or {nc+1} functions are expected; got {nf}") consts = {i: c for i, c in enumerate(funclist) if not callable(c)} funcs = {i: f for i, f in enumerate(funclist) if callable(f)} - return _piecewise(asarray(x), asarray(condlist, dtype=bool_), consts, + return _piecewise(x_arr, asarray(condlist, dtype=bool_), consts, frozenset(funcs.items()), # dict is not hashable. *args, **kw) @@ -13444,7 +13400,7 @@ def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *, [0, 5, 0, 0, 1], [0, 0, 3, 0, 0]], dtype=int32) """ - util.check_arraylike("place", arr, mask, vals) - data, mask_arr, vals_arr = asarray(arr), asarray(mask), ravel(vals) + data, mask_arr, vals_arr = util.ensure_arraylike("place", arr, mask, vals) + vals_arr = vals_arr.ravel() if inplace: raise ValueError( @@ -13526,8 +13482,9 @@ def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike, [ 0, 0, 20, 0, 0], [ 0, 0, 0, 0, 30]], dtype=int32) """ - util.check_arraylike("put", a, ind, v) - arr, ind_arr, v_arr = asarray(a), ravel(ind), ravel(v) + arr, ind_arr, v_arr = util.ensure_arraylike("put", a, ind, v) + ind_arr = ind_arr.ravel() + v_arr = v_arr.ravel() if not arr.size or not ind_arr.size or not v_arr.size: return arr v_arr = _tile_to_size(v_arr, len(ind_arr)) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 6ca5b2f0a..7a5adfc40 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -15,7 +15,7 @@ from __future__ import annotations from collections.abc import Sequence from functools import partial -from typing import Any +from typing import Any, overload import warnings @@ -133,6 +133,42 @@ def _arraylike(x: ArrayLike) -> bool: hasattr(x, 
'__jax_array__') or np.isscalar(x)) +def _arraylike_asarray(x: Any) -> Array: + """Convert one array-like to an Array: unwrap __jax_array__, coerce Python scalars, then lax.asarray.""" + if hasattr(x, '__jax_array__'): + x = x.__jax_array__() + elif isinstance(x, (bool, int, float, complex)): + x = dtypes.coerce_to_array(x) + return lax.asarray(x) + + +@overload +def ensure_arraylike(fun_name: str, /) -> tuple[()]: ... +@overload +def ensure_arraylike(fun_name: str, a1: Any, /) -> Array: ... +@overload +def ensure_arraylike(fun_name: str, a1: Any, a2: Any, /) -> tuple[Array, Array]: ... +@overload +def ensure_arraylike(fun_name: str, a1: Any, a2: Any, a3: Any, /) -> tuple[Array, Array, Array]: ... +@overload +def ensure_arraylike(fun_name: str, a1: Any, a2: Any, a3: Any, a4: Any, /, *args: Any) -> tuple[Array, ...]: ... +def ensure_arraylike(fun_name: str, /, *args: Any) -> Array | tuple[Array, ...]: + """Check args with check_arraylike, then convert: one arg gives an Array, otherwise a tuple of Arrays.""" + check_arraylike(fun_name, *args) + if len(args) == 1: + return _arraylike_asarray(args[0]) # pytype: disable=bad-return-type + return tuple(_arraylike_asarray(arg) for arg in args) # pytype: disable=bad-return-type + + +def ensure_arraylike_tuple(fun_name: str, tup: tuple[Any, ...]) -> tuple[Array, ...]: + """Check that every element of tup is arraylike and convert them to a tuple of arrays. + + Unlike ensure_arraylike, the result is always a tuple, even when tup has exactly one element. + """ + check_arraylike(fun_name, *tup) + return tuple(_arraylike_asarray(arg) for arg in tup) + + def check_arraylike(fun_name: str, *args: Any, emit_warning=False, stacklevel=3): """Check if all args fit JAX's definition of arraylike.""" assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}"