mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Remove internal ndarray type name. Use Array throughout.
jax.numpy.ndarray remains an exported alias for jax.Array. PiperOrigin-RevId: 513046188
This commit is contained in:
parent
52a7701dda
commit
a4412e2715
@ -57,7 +57,6 @@ from jax._src.lax.lax import (_array_copy, _sort_lt_comparator,
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lib import pmap_lib
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.numpy.ndarray import ndarray
|
||||
from jax._src.numpy.reductions import ( # noqa: F401
|
||||
_ensure_optional_axes, _reduction_dims,
|
||||
alltrue, amin, amax, any, all, average, count_nonzero, cumsum, cumprod, cumproduct,
|
||||
@ -288,7 +287,7 @@ def _convert_and_clip_integer(val: ArrayLike, dtype: DType) -> Array:
|
||||
>>> _convert_and_clip_integer(val, 'int32')
|
||||
Array(2147483647, dtype=int32)
|
||||
"""
|
||||
val = val if isinstance(val, ndarray) else asarray(val)
|
||||
val = val if isinstance(val, Array) else asarray(val)
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
if not (issubdtype(dtype, integer) and issubdtype(val.dtype, integer)):
|
||||
raise TypeError("_convert_and_clip_integer only accepts integer dtypes.")
|
||||
@ -1206,7 +1205,7 @@ def _split(op: str, ary: ArrayLike, indices_or_sections: Union[int, ArrayLike],
|
||||
for i_s in indices_or_sections], np.int64)
|
||||
split_indices = np.concatenate([[np.int64(0)], indices_or_sections,
|
||||
[np.int64(size)]])
|
||||
elif (isinstance(indices_or_sections, (np.ndarray, ndarray)) and
|
||||
elif (isinstance(indices_or_sections, (np.ndarray, Array)) and
|
||||
indices_or_sections.ndim > 0):
|
||||
indices_or_sections = np.array(
|
||||
[core.concrete_or_error(np.int64, i_s, f"in jax.numpy.{op} argument 1")
|
||||
@ -1777,7 +1776,7 @@ def stack(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]],
|
||||
raise ValueError("Need at least one array to stack.")
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.stack is not supported.")
|
||||
if isinstance(arrays, (np.ndarray, ndarray)):
|
||||
if isinstance(arrays, (np.ndarray, Array)):
|
||||
axis = _canonicalize_axis(axis, arrays.ndim)
|
||||
return concatenate(expand_dims(arrays, axis + 1), axis=axis, dtype=dtype)
|
||||
else:
|
||||
@ -1826,7 +1825,7 @@ def _concatenate_array(arr: ArrayLike, axis: Optional[int],
|
||||
@_wraps(np.concatenate)
|
||||
def concatenate(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]],
|
||||
axis: Optional[int] = 0, dtype: Optional[DTypeLike] = None) -> Array:
|
||||
if isinstance(arrays, (np.ndarray, ndarray)):
|
||||
if isinstance(arrays, (np.ndarray, Array)):
|
||||
return _concatenate_array(arrays, axis, dtype=dtype)
|
||||
_stackable(*arrays) or _check_arraylike("concatenate", *arrays)
|
||||
if not len(arrays):
|
||||
@ -1855,7 +1854,7 @@ def concatenate(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]],
|
||||
@_wraps(np.vstack)
|
||||
def vstack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]],
|
||||
dtype: Optional[DTypeLike] = None) -> Array:
|
||||
if isinstance(tup, (np.ndarray, ndarray)):
|
||||
if isinstance(tup, (np.ndarray, Array)):
|
||||
arrs = jax.vmap(atleast_2d)(tup)
|
||||
else:
|
||||
arrs = [atleast_2d(m) for m in tup]
|
||||
@ -1866,7 +1865,7 @@ row_stack = vstack
|
||||
@_wraps(np.hstack)
|
||||
def hstack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]],
|
||||
dtype: Optional[DTypeLike] = None) -> Array:
|
||||
if isinstance(tup, (np.ndarray, ndarray)):
|
||||
if isinstance(tup, (np.ndarray, Array)):
|
||||
arrs = jax.vmap(atleast_1d)(tup)
|
||||
arr0_ndim = arrs.ndim - 1
|
||||
else:
|
||||
@ -1878,7 +1877,7 @@ def hstack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]],
|
||||
@_wraps(np.dstack)
|
||||
def dstack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]],
|
||||
dtype: Optional[DTypeLike] = None) -> Array:
|
||||
if isinstance(tup, (np.ndarray, ndarray)):
|
||||
if isinstance(tup, (np.ndarray, Array)):
|
||||
arrs = jax.vmap(atleast_3d)(tup)
|
||||
else:
|
||||
arrs = [atleast_3d(m) for m in tup]
|
||||
@ -1887,7 +1886,7 @@ def dstack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]],
|
||||
|
||||
@_wraps(np.column_stack)
|
||||
def column_stack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]]) -> Array:
|
||||
if isinstance(tup, (np.ndarray, ndarray)):
|
||||
if isinstance(tup, (np.ndarray, Array)):
|
||||
arrs = jax.vmap(lambda x: atleast_2d(x).T)(tup) if tup.ndim < 3 else tup
|
||||
else:
|
||||
arrs = [atleast_2d(arr).T if arr.ndim < 2 else arr for arr in map(asarray, tup)]
|
||||
@ -2926,7 +2925,7 @@ def insert(arr, obj, values, axis=None):
|
||||
raise ValueError("jnp.insert(): obj must be a slice, a one-dimensional "
|
||||
f"array, or a scalar; got {obj}")
|
||||
if not np.issubdtype(indices.dtype, np.integer):
|
||||
if indices.size == 0 and not isinstance(obj, ndarray):
|
||||
if indices.size == 0 and not isinstance(obj, Array):
|
||||
indices = indices.astype(int)
|
||||
else:
|
||||
# Note: np.insert allows boolean inputs but the behavior is deprecated.
|
||||
@ -4089,7 +4088,7 @@ def _index_to_gather(x_shape, idx, normalize_indices=True):
|
||||
idx_no_nones = [(i, d) for i, d in enumerate(idx) if d is not None]
|
||||
advanced_pairs = (
|
||||
(asarray(e), i, j) for j, (i, e) in enumerate(idx_no_nones)
|
||||
if isscalar(e) or isinstance(e, (Sequence, ndarray, np.ndarray)))
|
||||
if isscalar(e) or isinstance(e, (Sequence, Array, np.ndarray)))
|
||||
if normalize_indices:
|
||||
advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j)
|
||||
for e, i, j in advanced_pairs)
|
||||
@ -4289,7 +4288,7 @@ def _index_to_gather(x_shape, idx, normalize_indices=True):
|
||||
|
||||
def _should_unpack_list_index(x):
|
||||
"""Helper for _eliminate_deprecated_list_indexing."""
|
||||
return (isinstance(x, (np.ndarray, ndarray)) and np.ndim(x) != 0
|
||||
return (isinstance(x, (np.ndarray, Array)) and np.ndim(x) != 0
|
||||
or isinstance(x, (Sequence, slice))
|
||||
or x is Ellipsis or x is None)
|
||||
|
||||
@ -4298,7 +4297,7 @@ def _eliminate_deprecated_list_indexing(idx):
|
||||
# non-tuple sequence containing slice objects, [Ellipses, or newaxis
|
||||
# objects]". Detects this and raises a TypeError.
|
||||
if not isinstance(idx, tuple):
|
||||
if isinstance(idx, Sequence) and not isinstance(idx, (ndarray, np.ndarray)):
|
||||
if isinstance(idx, Sequence) and not isinstance(idx, (Array, np.ndarray)):
|
||||
# As of numpy 1.16, some non-tuple sequences of indices result in a warning, while
|
||||
# others are converted to arrays, based on a set of somewhat convoluted heuristics
|
||||
# (See https://github.com/numpy/numpy/blob/v1.19.2/numpy/core/src/multiarray/mapping.c#L179-L343)
|
||||
@ -4393,7 +4392,7 @@ def _is_int_arraylike(x):
|
||||
|
||||
def _is_scalar(x):
|
||||
"""Checks if a Python or NumPy scalar."""
|
||||
return np.isscalar(x) or (isinstance(x, (np.ndarray, ndarray))
|
||||
return np.isscalar(x) or (isinstance(x, (np.ndarray, Array))
|
||||
and np.ndim(x) == 0)
|
||||
|
||||
def _canonicalize_tuple_index(arr_ndim, idx, array_name='array'):
|
||||
@ -5105,7 +5104,7 @@ def put(*args, **kwargs):
|
||||
# functions, which can themselves handle instances from any of these classes.
|
||||
|
||||
_scalar_types = (int, float, complex, np.generic)
|
||||
_accepted_binop_types = (int, float, complex, np.generic, np.ndarray, ndarray)
|
||||
_accepted_binop_types = (int, float, complex, np.generic, np.ndarray, Array)
|
||||
_rejected_binop_types = (list, tuple, set, dict)
|
||||
|
||||
def _defer_to_unrecognized_arg(opchar, binary_op, swap=False):
|
||||
@ -5327,7 +5326,7 @@ class _IndexUpdateHelper:
|
||||
|
||||
def __repr__(self):
|
||||
return f"_IndexUpdateHelper({repr(self.array)})"
|
||||
ndarray.at.__doc__ = _IndexUpdateHelper.__doc__
|
||||
Array.at.__doc__ = _IndexUpdateHelper.__doc__
|
||||
|
||||
_power_fn = power
|
||||
_divide_fn = divide
|
||||
|
@ -1,17 +0,0 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
__all__ = ['ndarray']
|
||||
|
||||
from jax._src.typing import Array as ndarray
|
@ -25,7 +25,6 @@ from jax import lax
|
||||
from jax._src import api
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.numpy.ndarray import ndarray
|
||||
from jax._src.numpy.util import (
|
||||
_broadcast_to, _check_arraylike, _complex_elem_type,
|
||||
_promote_dtypes_inexact, _promote_dtypes_numeric, _where, _wraps)
|
||||
@ -44,7 +43,7 @@ Axis = Union[None, int, Sequence[int]]
|
||||
|
||||
def _asarray(a: ArrayLike) -> Array:
|
||||
# simplified version of jnp.asarray() for local use.
|
||||
return a if isinstance(a, ndarray) else api.device_put(a)
|
||||
return a if isinstance(a, Array) else api.device_put(a)
|
||||
|
||||
def _isscalar(element: Any) -> bool:
|
||||
if hasattr(element, '__jax_array__'):
|
||||
@ -92,7 +91,7 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val:
|
||||
raise ValueError(f"reduction operation {name} does not have an identity, so to use a "
|
||||
f"where mask one has to specify 'initial'")
|
||||
|
||||
a = a if isinstance(a, ndarray) else _asarray(a)
|
||||
a = a if isinstance(a, Array) else _asarray(a)
|
||||
a = preproc(a) if preproc else a
|
||||
pos_dims, dims = _reduction_dims(a, axis)
|
||||
|
||||
|
@ -25,7 +25,6 @@ from jax._src import api
|
||||
from jax._src import core
|
||||
from jax._src.config import config
|
||||
from jax._src.lax import lax
|
||||
from jax._src.numpy.ndarray import ndarray
|
||||
from jax._src.util import safe_zip, safe_map
|
||||
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, Shape
|
||||
|
||||
@ -324,7 +323,7 @@ def _complex_elem_type(dtype: DTypeLike) -> DType:
|
||||
|
||||
|
||||
def _arraylike(x: ArrayLike) -> bool:
|
||||
return (isinstance(x, np.ndarray) or isinstance(x, ndarray) or
|
||||
return (isinstance(x, np.ndarray) or isinstance(x, Array) or
|
||||
hasattr(x, '__jax_array__') or np.isscalar(x))
|
||||
|
||||
|
||||
@ -393,7 +392,7 @@ def _broadcast_to(arr: ArrayLike, shape: Shape) -> Array:
|
||||
if hasattr(arr, "broadcast_to"):
|
||||
return arr.broadcast_to(shape) # type: ignore[union-attr]
|
||||
_check_arraylike("broadcast_to", arr)
|
||||
arr = arr if isinstance(arr, ndarray) else _asarray(arr)
|
||||
arr = arr if isinstance(arr, Array) else _asarray(arr)
|
||||
if not isinstance(shape, tuple) and np.ndim(shape) == 0:
|
||||
shape = (shape,)
|
||||
shape = core.canonicalize_shape(shape) # check that shape is concrete
|
||||
|
@ -20,6 +20,8 @@ from jax.numpy import linalg as linalg
|
||||
|
||||
from jax._src.device_array import DeviceArray as DeviceArray
|
||||
|
||||
from jax._src.basearray import Array as ndarray
|
||||
|
||||
from jax._src.numpy.lax_numpy import (
|
||||
ComplexWarning as ComplexWarning,
|
||||
NINF as NINF,
|
||||
@ -178,7 +180,6 @@ from jax._src.numpy.lax_numpy import (
|
||||
nanmedian as nanmedian,
|
||||
nanpercentile as nanpercentile,
|
||||
nanquantile as nanquantile,
|
||||
ndarray as ndarray,
|
||||
ndim as ndim,
|
||||
newaxis as newaxis,
|
||||
nonzero as nonzero,
|
||||
|
Loading…
x
Reference in New Issue
Block a user