Remove internal ndarray type name. Use Array throughout.

jax.numpy.ndarray remains an exported alias for jax.Array.

PiperOrigin-RevId: 513046188
This commit is contained in:
Peter Hawkins 2023-02-28 14:50:33 -08:00 committed by jax authors
parent 52a7701dda
commit a4412e2715
5 changed files with 21 additions and 40 deletions

View File

@ -57,7 +57,6 @@ from jax._src.lax.lax import (_array_copy, _sort_lt_comparator,
from jax._src.lax import lax as lax_internal
from jax._src.lib import pmap_lib
from jax._src.lib import xla_client
from jax._src.numpy.ndarray import ndarray
from jax._src.numpy.reductions import ( # noqa: F401
_ensure_optional_axes, _reduction_dims,
alltrue, amin, amax, any, all, average, count_nonzero, cumsum, cumprod, cumproduct,
@ -288,7 +287,7 @@ def _convert_and_clip_integer(val: ArrayLike, dtype: DType) -> Array:
>>> _convert_and_clip_integer(val, 'int32')
Array(2147483647, dtype=int32)
"""
val = val if isinstance(val, ndarray) else asarray(val)
val = val if isinstance(val, Array) else asarray(val)
dtype = dtypes.canonicalize_dtype(dtype)
if not (issubdtype(dtype, integer) and issubdtype(val.dtype, integer)):
raise TypeError("_convert_and_clip_integer only accepts integer dtypes.")
@ -1206,7 +1205,7 @@ def _split(op: str, ary: ArrayLike, indices_or_sections: Union[int, ArrayLike],
for i_s in indices_or_sections], np.int64)
split_indices = np.concatenate([[np.int64(0)], indices_or_sections,
[np.int64(size)]])
elif (isinstance(indices_or_sections, (np.ndarray, ndarray)) and
elif (isinstance(indices_or_sections, (np.ndarray, Array)) and
indices_or_sections.ndim > 0):
indices_or_sections = np.array(
[core.concrete_or_error(np.int64, i_s, f"in jax.numpy.{op} argument 1")
@ -1777,7 +1776,7 @@ def stack(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]],
raise ValueError("Need at least one array to stack.")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.stack is not supported.")
if isinstance(arrays, (np.ndarray, ndarray)):
if isinstance(arrays, (np.ndarray, Array)):
axis = _canonicalize_axis(axis, arrays.ndim)
return concatenate(expand_dims(arrays, axis + 1), axis=axis, dtype=dtype)
else:
@ -1826,7 +1825,7 @@ def _concatenate_array(arr: ArrayLike, axis: Optional[int],
@_wraps(np.concatenate)
def concatenate(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]],
axis: Optional[int] = 0, dtype: Optional[DTypeLike] = None) -> Array:
if isinstance(arrays, (np.ndarray, ndarray)):
if isinstance(arrays, (np.ndarray, Array)):
return _concatenate_array(arrays, axis, dtype=dtype)
_stackable(*arrays) or _check_arraylike("concatenate", *arrays)
if not len(arrays):
@ -1855,7 +1854,7 @@ def concatenate(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]],
@_wraps(np.vstack)
def vstack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]],
dtype: Optional[DTypeLike] = None) -> Array:
if isinstance(tup, (np.ndarray, ndarray)):
if isinstance(tup, (np.ndarray, Array)):
arrs = jax.vmap(atleast_2d)(tup)
else:
arrs = [atleast_2d(m) for m in tup]
@ -1866,7 +1865,7 @@ row_stack = vstack
@_wraps(np.hstack)
def hstack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]],
dtype: Optional[DTypeLike] = None) -> Array:
if isinstance(tup, (np.ndarray, ndarray)):
if isinstance(tup, (np.ndarray, Array)):
arrs = jax.vmap(atleast_1d)(tup)
arr0_ndim = arrs.ndim - 1
else:
@ -1878,7 +1877,7 @@ def hstack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]],
@_wraps(np.dstack)
def dstack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]],
dtype: Optional[DTypeLike] = None) -> Array:
if isinstance(tup, (np.ndarray, ndarray)):
if isinstance(tup, (np.ndarray, Array)):
arrs = jax.vmap(atleast_3d)(tup)
else:
arrs = [atleast_3d(m) for m in tup]
@ -1887,7 +1886,7 @@ def dstack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]],
@_wraps(np.column_stack)
def column_stack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]]) -> Array:
if isinstance(tup, (np.ndarray, ndarray)):
if isinstance(tup, (np.ndarray, Array)):
arrs = jax.vmap(lambda x: atleast_2d(x).T)(tup) if tup.ndim < 3 else tup
else:
arrs = [atleast_2d(arr).T if arr.ndim < 2 else arr for arr in map(asarray, tup)]
@ -2926,7 +2925,7 @@ def insert(arr, obj, values, axis=None):
raise ValueError("jnp.insert(): obj must be a slice, a one-dimensional "
f"array, or a scalar; got {obj}")
if not np.issubdtype(indices.dtype, np.integer):
if indices.size == 0 and not isinstance(obj, ndarray):
if indices.size == 0 and not isinstance(obj, Array):
indices = indices.astype(int)
else:
# Note: np.insert allows boolean inputs but the behavior is deprecated.
@ -4089,7 +4088,7 @@ def _index_to_gather(x_shape, idx, normalize_indices=True):
idx_no_nones = [(i, d) for i, d in enumerate(idx) if d is not None]
advanced_pairs = (
(asarray(e), i, j) for j, (i, e) in enumerate(idx_no_nones)
if isscalar(e) or isinstance(e, (Sequence, ndarray, np.ndarray)))
if isscalar(e) or isinstance(e, (Sequence, Array, np.ndarray)))
if normalize_indices:
advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j)
for e, i, j in advanced_pairs)
@ -4289,7 +4288,7 @@ def _index_to_gather(x_shape, idx, normalize_indices=True):
def _should_unpack_list_index(x):
"""Helper for _eliminate_deprecated_list_indexing."""
return (isinstance(x, (np.ndarray, ndarray)) and np.ndim(x) != 0
return (isinstance(x, (np.ndarray, Array)) and np.ndim(x) != 0
or isinstance(x, (Sequence, slice))
or x is Ellipsis or x is None)
@ -4298,7 +4297,7 @@ def _eliminate_deprecated_list_indexing(idx):
# non-tuple sequence containing slice objects, [Ellipses, or newaxis
# objects]". Detects this and raises a TypeError.
if not isinstance(idx, tuple):
if isinstance(idx, Sequence) and not isinstance(idx, (ndarray, np.ndarray)):
if isinstance(idx, Sequence) and not isinstance(idx, (Array, np.ndarray)):
# As of numpy 1.16, some non-tuple sequences of indices result in a warning, while
# others are converted to arrays, based on a set of somewhat convoluted heuristics
# (See https://github.com/numpy/numpy/blob/v1.19.2/numpy/core/src/multiarray/mapping.c#L179-L343)
@ -4393,7 +4392,7 @@ def _is_int_arraylike(x):
def _is_scalar(x):
"""Checks if a Python or NumPy scalar."""
return np.isscalar(x) or (isinstance(x, (np.ndarray, ndarray))
return np.isscalar(x) or (isinstance(x, (np.ndarray, Array))
and np.ndim(x) == 0)
def _canonicalize_tuple_index(arr_ndim, idx, array_name='array'):
@ -5105,7 +5104,7 @@ def put(*args, **kwargs):
# functions, which can themselves handle instances from any of these classes.
_scalar_types = (int, float, complex, np.generic)
_accepted_binop_types = (int, float, complex, np.generic, np.ndarray, ndarray)
_accepted_binop_types = (int, float, complex, np.generic, np.ndarray, Array)
_rejected_binop_types = (list, tuple, set, dict)
def _defer_to_unrecognized_arg(opchar, binary_op, swap=False):
@ -5327,7 +5326,7 @@ class _IndexUpdateHelper:
def __repr__(self):
return f"_IndexUpdateHelper({repr(self.array)})"
ndarray.at.__doc__ = _IndexUpdateHelper.__doc__
Array.at.__doc__ = _IndexUpdateHelper.__doc__
_power_fn = power
_divide_fn = divide

View File

@ -1,17 +0,0 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['ndarray']
from jax._src.typing import Array as ndarray

View File

@ -25,7 +25,6 @@ from jax import lax
from jax._src import api
from jax._src import core
from jax._src import dtypes
from jax._src.numpy.ndarray import ndarray
from jax._src.numpy.util import (
_broadcast_to, _check_arraylike, _complex_elem_type,
_promote_dtypes_inexact, _promote_dtypes_numeric, _where, _wraps)
@ -44,7 +43,7 @@ Axis = Union[None, int, Sequence[int]]
def _asarray(a: ArrayLike) -> Array:
# simplified version of jnp.asarray() for local use.
return a if isinstance(a, ndarray) else api.device_put(a)
return a if isinstance(a, Array) else api.device_put(a)
def _isscalar(element: Any) -> bool:
if hasattr(element, '__jax_array__'):
@ -92,7 +91,7 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val:
raise ValueError(f"reduction operation {name} does not have an identity, so to use a "
f"where mask one has to specify 'initial'")
a = a if isinstance(a, ndarray) else _asarray(a)
a = a if isinstance(a, Array) else _asarray(a)
a = preproc(a) if preproc else a
pos_dims, dims = _reduction_dims(a, axis)

View File

@ -25,7 +25,6 @@ from jax._src import api
from jax._src import core
from jax._src.config import config
from jax._src.lax import lax
from jax._src.numpy.ndarray import ndarray
from jax._src.util import safe_zip, safe_map
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, Shape
@ -324,7 +323,7 @@ def _complex_elem_type(dtype: DTypeLike) -> DType:
def _arraylike(x: ArrayLike) -> bool:
return (isinstance(x, np.ndarray) or isinstance(x, ndarray) or
return (isinstance(x, np.ndarray) or isinstance(x, Array) or
hasattr(x, '__jax_array__') or np.isscalar(x))
@ -393,7 +392,7 @@ def _broadcast_to(arr: ArrayLike, shape: Shape) -> Array:
if hasattr(arr, "broadcast_to"):
return arr.broadcast_to(shape) # type: ignore[union-attr]
_check_arraylike("broadcast_to", arr)
arr = arr if isinstance(arr, ndarray) else _asarray(arr)
arr = arr if isinstance(arr, Array) else _asarray(arr)
if not isinstance(shape, tuple) and np.ndim(shape) == 0:
shape = (shape,)
shape = core.canonicalize_shape(shape) # check that shape is concrete

View File

@ -20,6 +20,8 @@ from jax.numpy import linalg as linalg
from jax._src.device_array import DeviceArray as DeviceArray
from jax._src.basearray import Array as ndarray
from jax._src.numpy.lax_numpy import (
ComplexWarning as ComplexWarning,
NINF as NINF,
@ -178,7 +180,6 @@ from jax._src.numpy.lax_numpy import (
nanmedian as nanmedian,
nanpercentile as nanpercentile,
nanquantile as nanquantile,
ndarray as ndarray,
ndim as ndim,
newaxis as newaxis,
nonzero as nonzero,