From a4412e2715b8692aa1438ed3aedfdad57b0666b1 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 28 Feb 2023 14:50:33 -0800 Subject: [PATCH] Remove internal ndarray type name. Use Array throughout. jax.numpy.ndarray remains an exported alias for jax.Array. PiperOrigin-RevId: 513046188 --- jax/_src/numpy/lax_numpy.py | 31 +++++++++++++++---------------- jax/_src/numpy/ndarray.py | 17 ----------------- jax/_src/numpy/reductions.py | 5 ++--- jax/_src/numpy/util.py | 5 ++--- jax/numpy/__init__.py | 3 ++- 5 files changed, 21 insertions(+), 40 deletions(-) delete mode 100644 jax/_src/numpy/ndarray.py diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 663a82e75..8c18dc5a1 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -57,7 +57,6 @@ from jax._src.lax.lax import (_array_copy, _sort_lt_comparator, from jax._src.lax import lax as lax_internal from jax._src.lib import pmap_lib from jax._src.lib import xla_client -from jax._src.numpy.ndarray import ndarray from jax._src.numpy.reductions import ( # noqa: F401 _ensure_optional_axes, _reduction_dims, alltrue, amin, amax, any, all, average, count_nonzero, cumsum, cumprod, cumproduct, @@ -288,7 +287,7 @@ def _convert_and_clip_integer(val: ArrayLike, dtype: DType) -> Array: >>> _convert_and_clip_integer(val, 'int32') Array(2147483647, dtype=int32) """ - val = val if isinstance(val, ndarray) else asarray(val) + val = val if isinstance(val, Array) else asarray(val) dtype = dtypes.canonicalize_dtype(dtype) if not (issubdtype(dtype, integer) and issubdtype(val.dtype, integer)): raise TypeError("_convert_and_clip_integer only accepts integer dtypes.") @@ -1206,7 +1205,7 @@ def _split(op: str, ary: ArrayLike, indices_or_sections: Union[int, ArrayLike], for i_s in indices_or_sections], np.int64) split_indices = np.concatenate([[np.int64(0)], indices_or_sections, [np.int64(size)]]) - elif (isinstance(indices_or_sections, (np.ndarray, ndarray)) and + elif (isinstance(indices_or_sections, (np.ndarray, Array)) and indices_or_sections.ndim > 0): indices_or_sections = np.array( [core.concrete_or_error(np.int64, i_s, f"in jax.numpy.{op} argument 1") @@ -1777,7 +1776,7 @@ def stack(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]], raise ValueError("Need at least one array to stack.") if out is not None: raise NotImplementedError("The 'out' argument to jnp.stack is not supported.") - if isinstance(arrays, (np.ndarray, ndarray)): + if isinstance(arrays, (np.ndarray, Array)): axis = _canonicalize_axis(axis, arrays.ndim) return concatenate(expand_dims(arrays, axis + 1), axis=axis, dtype=dtype) else: @@ -1826,7 +1825,7 @@ def _concatenate_array(arr: ArrayLike, axis: Optional[int], @_wraps(np.concatenate) def concatenate(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]], axis: Optional[int] = 0, dtype: Optional[DTypeLike] = None) -> Array: - if isinstance(arrays, (np.ndarray, ndarray)): + if isinstance(arrays, (np.ndarray, Array)): return _concatenate_array(arrays, axis, dtype=dtype) _stackable(*arrays) or _check_arraylike("concatenate", *arrays) if not len(arrays): @@ -1855,7 +1854,7 @@ def concatenate(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]], @_wraps(np.vstack) def vstack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]], dtype: Optional[DTypeLike] = None) -> Array: - if isinstance(tup, (np.ndarray, ndarray)): + if isinstance(tup, (np.ndarray, Array)): arrs = jax.vmap(atleast_2d)(tup) else: arrs = [atleast_2d(m) for m in tup] @@ -1866,7 +1865,7 @@ row_stack = vstack @_wraps(np.hstack) def hstack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]], dtype: Optional[DTypeLike] = None) -> Array: - if isinstance(tup, (np.ndarray, ndarray)): + if isinstance(tup, (np.ndarray, Array)): arrs = jax.vmap(atleast_1d)(tup) arr0_ndim = arrs.ndim - 1 else: @@ -1878,7 +1877,7 @@ def hstack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]], @_wraps(np.dstack) def dstack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]], dtype: Optional[DTypeLike] = None) -> Array: - if isinstance(tup, (np.ndarray, ndarray)): + if isinstance(tup, (np.ndarray, Array)): arrs = jax.vmap(atleast_3d)(tup) else: arrs = [atleast_3d(m) for m in tup] @@ -1887,7 +1886,7 @@ def dstack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]], @_wraps(np.column_stack) def column_stack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]]) -> Array: - if isinstance(tup, (np.ndarray, ndarray)): + if isinstance(tup, (np.ndarray, Array)): arrs = jax.vmap(lambda x: atleast_2d(x).T)(tup) if tup.ndim < 3 else tup else: arrs = [atleast_2d(arr).T if arr.ndim < 2 else arr for arr in map(asarray, tup)] @@ -2926,7 +2925,7 @@ def insert(arr, obj, values, axis=None): raise ValueError("jnp.insert(): obj must be a slice, a one-dimensional " f"array, or a scalar; got {obj}") if not np.issubdtype(indices.dtype, np.integer): - if indices.size == 0 and not isinstance(obj, ndarray): + if indices.size == 0 and not isinstance(obj, Array): indices = indices.astype(int) else: # Note: np.insert allows boolean inputs but the behavior is deprecated. @@ -4089,7 +4088,7 @@ def _index_to_gather(x_shape, idx, normalize_indices=True): idx_no_nones = [(i, d) for i, d in enumerate(idx) if d is not None] advanced_pairs = ( (asarray(e), i, j) for j, (i, e) in enumerate(idx_no_nones) - if isscalar(e) or isinstance(e, (Sequence, ndarray, np.ndarray))) + if isscalar(e) or isinstance(e, (Sequence, Array, np.ndarray))) if normalize_indices: advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j) for e, i, j in advanced_pairs) @@ -4289,7 +4288,7 @@ def _index_to_gather(x_shape, idx, normalize_indices=True): def _should_unpack_list_index(x): """Helper for _eliminate_deprecated_list_indexing.""" - return (isinstance(x, (np.ndarray, ndarray)) and np.ndim(x) != 0 + return (isinstance(x, (np.ndarray, Array)) and np.ndim(x) != 0 or isinstance(x, (Sequence, slice)) or x is Ellipsis or x is None) @@ -4298,7 +4297,7 @@ def _eliminate_deprecated_list_indexing(idx): # non-tuple sequence containing slice objects, [Ellipses, or newaxis # objects]". Detects this and raises a TypeError. if not isinstance(idx, tuple): - if isinstance(idx, Sequence) and not isinstance(idx, (ndarray, np.ndarray)): + if isinstance(idx, Sequence) and not isinstance(idx, (Array, np.ndarray)): # As of numpy 1.16, some non-tuple sequences of indices result in a warning, while # others are converted to arrays, based on a set of somewhat convoluted heuristics # (See https://github.com/numpy/numpy/blob/v1.19.2/numpy/core/src/multiarray/mapping.c#L179-L343) @@ -4393,7 +4392,7 @@ def _is_int_arraylike(x): def _is_scalar(x): """Checks if a Python or NumPy scalar.""" - return np.isscalar(x) or (isinstance(x, (np.ndarray, ndarray)) + return np.isscalar(x) or (isinstance(x, (np.ndarray, Array)) and np.ndim(x) == 0) def _canonicalize_tuple_index(arr_ndim, idx, array_name='array'): @@ -5105,7 +5104,7 @@ def put(*args, **kwargs): # functions, which can themselves handle instances from any of these classes. _scalar_types = (int, float, complex, np.generic) -_accepted_binop_types = (int, float, complex, np.generic, np.ndarray, ndarray) +_accepted_binop_types = (int, float, complex, np.generic, np.ndarray, Array) _rejected_binop_types = (list, tuple, set, dict) def _defer_to_unrecognized_arg(opchar, binary_op, swap=False): @@ -5327,7 +5326,7 @@ class _IndexUpdateHelper: def __repr__(self): return f"_IndexUpdateHelper({repr(self.array)})" -ndarray.at.__doc__ = _IndexUpdateHelper.__doc__ +Array.at.__doc__ = _IndexUpdateHelper.__doc__ _power_fn = power _divide_fn = divide diff --git a/jax/_src/numpy/ndarray.py b/jax/_src/numpy/ndarray.py deleted file mode 100644 index bb1300081..000000000 --- a/jax/_src/numpy/ndarray.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2022 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -__all__ = ['ndarray'] - -from jax._src.typing import Array as ndarray diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 61d5be671..2b6ff3a6b 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -25,7 +25,6 @@ from jax import lax from jax._src import api from jax._src import core from jax._src import dtypes -from jax._src.numpy.ndarray import ndarray from jax._src.numpy.util import ( _broadcast_to, _check_arraylike, _complex_elem_type, _promote_dtypes_inexact, _promote_dtypes_numeric, _where, _wraps) @@ -44,7 +43,7 @@ Axis = Union[None, int, Sequence[int]] def _asarray(a: ArrayLike) -> Array: # simplified version of jnp.asarray() for local use. - return a if isinstance(a, ndarray) else api.device_put(a) + return a if isinstance(a, Array) else api.device_put(a) def _isscalar(element: Any) -> bool: if hasattr(element, '__jax_array__'): @@ -92,7 +91,7 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: raise ValueError(f"reduction operation {name} does not have an identity, so to use a " f"where mask one has to specify 'initial'") - a = a if isinstance(a, ndarray) else _asarray(a) + a = a if isinstance(a, Array) else _asarray(a) a = preproc(a) if preproc else a pos_dims, dims = _reduction_dims(a, axis) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 450707e80..d4679a173 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -25,7 +25,6 @@ from jax._src import api from jax._src import core from jax._src.config import config from jax._src.lax import lax -from jax._src.numpy.ndarray import ndarray from jax._src.util import safe_zip, safe_map from jax._src.typing import Array, ArrayLike, DType, DTypeLike, Shape @@ -324,7 +323,7 @@ def _complex_elem_type(dtype: DTypeLike) -> DType: def _arraylike(x: ArrayLike) -> bool: - return (isinstance(x, np.ndarray) or isinstance(x, ndarray) or + return (isinstance(x, np.ndarray) or isinstance(x, Array) or hasattr(x, '__jax_array__') or np.isscalar(x)) @@ -393,7 +392,7 @@ def _broadcast_to(arr: ArrayLike, shape: Shape) -> Array: if hasattr(arr, "broadcast_to"): return arr.broadcast_to(shape) # type: ignore[union-attr] _check_arraylike("broadcast_to", arr) - arr = arr if isinstance(arr, ndarray) else _asarray(arr) + arr = arr if isinstance(arr, Array) else _asarray(arr) if not isinstance(shape, tuple) and np.ndim(shape) == 0: shape = (shape,) shape = core.canonicalize_shape(shape) # check that shape is concrete diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 9f00edeac..7c9467ef3 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -20,6 +20,8 @@ from jax.numpy import linalg as linalg from jax._src.device_array import DeviceArray as DeviceArray +from jax._src.basearray import Array as ndarray + from jax._src.numpy.lax_numpy import ( ComplexWarning as ComplexWarning, NINF as NINF, @@ -178,7 +180,6 @@ from jax._src.numpy.lax_numpy import ( nanmedian as nanmedian, nanpercentile as nanpercentile, nanquantile as nanquantile, - ndarray as ndarray, ndim as ndim, newaxis as newaxis, nonzero as nonzero,