refactor: import numpy objects directly in jax.numpy

This commit is contained in:
Jake VanderPlas 2025-02-14 11:22:18 -08:00
parent 36d7f8530b
commit 33b989ac9e
6 changed files with 137 additions and 158 deletions

View File

@ -61,7 +61,7 @@ def _mask(x, dims, alternative=0):
Replaces values outside those dimensions with `alternative`. `alternative` is
broadcast with `x`.
"""
assert jnp.ndim(x) == len(dims)
assert np.ndim(x) == len(dims)
mask = None
for i, d in enumerate(dims):
if d is not None:
@ -145,7 +145,7 @@ def _projector_subspace(P, H, n, rank, maxiter=2, swap=False):
N, _ = P.shape
negative_column_norms = -jnp_linalg.norm(P, axis=1)
# `jnp.argsort` ensures NaNs sort last, so set masked-out column norms to NaN.
negative_column_norms = _mask(negative_column_norms, (n,), jnp.nan)
negative_column_norms = _mask(negative_column_norms, (n,), np.nan)
sort_idxs = jnp.argsort(negative_column_norms)
X = P[:, sort_idxs]
# X = X[:, :rank]
@ -397,7 +397,7 @@ def _eigh_work(H, n, termination_size, subset_by_index):
def default_case(agenda, blocks, eigenvectors):
V = _slice(eigenvectors, (0, offset), (n, b), (N, B))
# TODO: Improve this?
split_point = reductions.nanmedian(_mask(jnp.diag(ufuncs.real(H)), (b,), jnp.nan))
split_point = reductions.nanmedian(_mask(jnp.diag(ufuncs.real(H)), (b,), np.nan))
H_minus, V_minus, H_plus, V_plus, rank = split_spectrum(
H, b, split_point, V0=V)
@ -564,7 +564,7 @@ def eigh(
eig_vals, eig_vecs = _eigh_work(
H, n, termination_size=termination_size, subset_by_index=subset_by_index
)
eig_vals = _mask(ufuncs.real(eig_vals), (n,), jnp.nan)
eig_vals = _mask(ufuncs.real(eig_vals), (n,), np.nan)
if sort_eigenvalues or compute_slice:
sort_idxs = jnp.argsort(eig_vals)
if compute_slice:

View File

@ -82,18 +82,8 @@ for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib']:
else:
break
newaxis = None
T = TypeVar('T')
# NumPy constants
pi = np.pi
e = np.e
euler_gamma = np.euler_gamma
inf = np.inf
nan = np.nan
# Wrappers for NumPy printoptions
def get_printoptions():
@ -169,9 +159,6 @@ def iscomplexobj(x: Any) -> bool:
typ = asarray(x).dtype.type
return issubdtype(typ, np.complexfloating)
shape = _shape = np.shape
ndim = _ndim = np.ndim
size = np.size
def _dtype(x: Any) -> DType:
return dtypes.dtype(x, canonicalize=True)
@ -180,19 +167,11 @@ def _dtype(x: Any) -> DType:
iinfo = dtypes.iinfo
finfo = dtypes.finfo
dtype = np.dtype
can_cast = dtypes.can_cast
promote_types = dtypes.promote_types
ComplexWarning = NumpyComplexWarning
# Numpy functions
array_str = np.array_str
array_repr = np.array_repr
save = np.save
savez = np.savez
_lax_const = lax_internal._const
@ -534,8 +513,6 @@ def isscalar(element: Any) -> bool:
return asarray(element).ndim == 0
return False
iterable = np.iterable
@export
def result_type(*args: Any) -> DType:
@ -621,7 +598,7 @@ def trunc(x: ArrayLike) -> Array:
@partial(jit, static_argnames=['mode', 'op', 'precision', 'preferred_element_type'])
def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike,
preferred_element_type: DTypeLike | None = None) -> Array:
if ndim(x) != 1 or ndim(y) != 1:
if np.ndim(x) != 1 or np.ndim(y) != 1:
raise ValueError(f"{op}() only support 1-dimensional inputs.")
if preferred_element_type is None:
# if unspecified, promote to inexact following NumPy's default for convolutions.
@ -856,7 +833,7 @@ def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10,
util.check_arraylike("histogram_bin_edges", a, bins)
arr = asarray(a)
dtype = dtypes.to_inexact_dtype(arr.dtype)
if _ndim(bins) == 1:
if np.ndim(bins) == 1:
return asarray(bins, dtype=dtype)
bins_int = core.concrete_or_error(operator.index, bins,
@ -864,7 +841,7 @@ def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10,
if range is None:
range = [arr.min(), arr.max()]
range = asarray(range, dtype=dtype)
if shape(range) != (2,):
if np.shape(range) != (2,):
raise ValueError(f"`range` must be either None or a sequence of scalars, got {range}")
range = (where(reductions.ptp(range) == 0, range[0] - 0.5, range[0]),
where(reductions.ptp(range) == 0, range[1] + 0.5, range[1]))
@ -940,7 +917,7 @@ def histogram(a: ArrayLike, bins: ArrayLike = 10,
weights = ones_like(a)
else:
util.check_arraylike("histogram", a, bins, weights)
if shape(a) != shape(weights):
if np.shape(a) != np.shape(weights):
raise ValueError("weights should have the same shape as a.")
a, weights = util.promote_dtypes_inexact(a, weights)
@ -1105,13 +1082,13 @@ def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10,
sample, = util.promote_dtypes_inexact(sample)
else:
util.check_arraylike("histogramdd", sample, weights)
if shape(weights) != shape(sample)[:1]:
if np.shape(weights) != np.shape(sample)[:1]:
raise ValueError("should have one weight for each sample.")
sample, weights = util.promote_dtypes_inexact(sample, weights)
N, D = shape(sample)
N, D = np.shape(sample)
if range is not None and (
len(range) != D or any(r is not None and shape(r)[0] != 2 for r in range)): # type: ignore[arg-type]
len(range) != D or any(r is not None and np.shape(r)[0] != 2 for r in range)): # type: ignore[arg-type]
raise ValueError(f"For sample.shape={(N, D)}, range must be a sequence "
f"of {D} pairs or Nones; got {range=}")
@ -1228,8 +1205,8 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array:
[2, 4]], dtype=int32)
"""
util.check_arraylike("transpose", a)
axes_ = list(range(ndim(a))[::-1]) if axes is None else axes
axes_ = [_canonicalize_axis(i, ndim(a)) for i in axes_]
axes_ = list(range(np.ndim(a))[::-1]) if axes is None else axes
axes_ = [_canonicalize_axis(i, np.ndim(a)) for i in axes_]
return lax.transpose(a, axes_)
@ -1383,8 +1360,8 @@ def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array:
f"two, but got first argument of shape {np.shape(m)}, "
f"which has ndim {np.ndim(m)}")
ax1, ax2 = axes
ax1 = _canonicalize_axis(ax1, ndim(m))
ax2 = _canonicalize_axis(ax2, ndim(m))
ax1 = _canonicalize_axis(ax1, np.ndim(m))
ax2 = _canonicalize_axis(ax2, np.ndim(m))
if ax1 == ax2:
raise ValueError("Axes must be different") # same as numpy error
k = k % 4
@ -1393,7 +1370,7 @@ def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array:
elif k == 2:
return flip(flip(m, ax1), ax2)
else:
perm = list(range(ndim(m)))
perm = list(range(np.ndim(m)))
perm[ax1], perm[ax2] = perm[ax2], perm[ax1]
if k == 1:
return transpose(flip(m, ax2), perm)
@ -1464,9 +1441,9 @@ def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array:
@partial(jit, static_argnames=('axis',))
def _flip(m: Array, axis: int | tuple[int, ...] | None = None) -> Array:
if axis is None:
return lax.rev(m, list(range(len(shape(m)))))
return lax.rev(m, list(range(len(np.shape(m)))))
axis = _ensure_index_tuple(axis)
return lax.rev(m, [_canonicalize_axis(ax, ndim(m)) for ax in axis])
return lax.rev(m, [_canonicalize_axis(ax, np.ndim(m)) for ax in axis])
@export
@ -1617,7 +1594,7 @@ def angle(z: ArrayLike, deg: bool = False) -> Array:
im = ufuncs.imag(z)
dtype = _dtype(re)
if not issubdtype(dtype, np.inexact) or (
issubdtype(_dtype(z), np.floating) and ndim(z) == 0):
issubdtype(_dtype(z), np.floating) and np.ndim(z) == 0):
dtype = dtypes.canonicalize_dtype(dtypes.float_)
re = lax.convert_element_type(re, dtype)
im = lax.convert_element_type(im, dtype)
@ -1704,7 +1681,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1,
combined: list[Array] = []
if prepend is not None:
prepend = util.ensure_arraylike("diff", prepend)
if not ndim(prepend):
if not np.ndim(prepend):
shape = list(arr.shape)
shape[axis] = 1
prepend = broadcast_to(prepend, tuple(shape))
@ -1714,7 +1691,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1,
if append is not None:
append = util.ensure_arraylike("diff", append)
if not ndim(append):
if not np.ndim(append):
shape = list(arr.shape)
shape[axis] = 1
append = broadcast_to(append, tuple(shape))
@ -1878,12 +1855,12 @@ def gradient(
upper_edge = sliced(1, 2) - sliced(0, 1)
lower_edge = sliced(-1, None) - sliced(-2, -1)
if ndim(h) == 0:
if np.ndim(h) == 0:
inner = (sliced(2, None) - sliced(None, -2)) * 0.5 / h
lower_edge /= h
upper_edge /= h
elif ndim(h) == 1:
elif np.ndim(h) == 1:
if len(h) != a.shape[axis]:
raise ValueError(
"Spacing arrays must have the same length as the "
@ -2112,7 +2089,7 @@ def ravel(a: ArrayLike, order: str = "C") -> Array:
util.check_arraylike("ravel", a)
if order == "K":
raise NotImplementedError("Ravel not implemented for order='K'.")
return reshape(a, (size(a),), order)
return reshape(a, (np.size(a),), order)
@export
@ -2259,7 +2236,7 @@ def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]:
# TODO: Consider warning here since shape is supposed to be a sequence, so
# this should not happen.
shape = [shape]
if any(ndim(s) != 0 for s in shape):
if any(np.ndim(s) != 0 for s in shape):
raise ValueError("unravel_index: shape should be a scalar or 1D sequence.")
out_indices: list[ArrayLike] = [0] * len(shape)
for i, s in reversed(list(enumerate(shape))):
@ -2385,7 +2362,7 @@ def squeeze(a: ArrayLike, axis: int | Sequence[int] | None = None) -> Array:
@partial(jit, static_argnames=('axis',), inline=True)
def _squeeze(a: Array, axis: tuple[int, ...]) -> Array:
if axis is None:
a_shape = shape(a)
a_shape = np.shape(a)
if not core.is_constant_shape(a_shape):
# We do not even know the rank of the output if the input shape is not known
raise ValueError("jnp.squeeze with axis=None is not supported with shape polymorphism")
@ -2507,7 +2484,7 @@ def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array:
(2, 5, 4, 3)
"""
util.check_arraylike("swapaxes", a)
perm = np.arange(ndim(a))
perm = np.arange(np.ndim(a))
perm[axis1], perm[axis2] = perm[axis2], perm[axis1]
return lax.transpose(a, list(perm))
@ -2567,12 +2544,12 @@ def moveaxis(a: ArrayLike, source: int | Sequence[int],
@partial(jit, static_argnames=('source', 'destination'), inline=True)
def _moveaxis(a: Array, source: tuple[int, ...], destination: tuple[int, ...]) -> Array:
source = tuple(_canonicalize_axis(i, ndim(a)) for i in source)
destination = tuple(_canonicalize_axis(i, ndim(a)) for i in destination)
source = tuple(_canonicalize_axis(i, np.ndim(a)) for i in source)
destination = tuple(_canonicalize_axis(i, np.ndim(a)) for i in destination)
if len(source) != len(destination):
raise ValueError("Inconsistent number of elements: {} vs {}"
.format(len(source), len(destination)))
perm = [i for i in range(ndim(a)) if i not in source]
perm = [i for i in range(np.ndim(a)) if i not in source]
for dest, src in sorted(zip(destination, source)):
perm.insert(dest, src)
return lax.transpose(a, perm)
@ -2666,7 +2643,7 @@ def _interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
right: ArrayLike | str | None = None,
period: ArrayLike | None = None) -> Array:
util.check_arraylike("interp", x, xp, fp)
if shape(xp) != shape(fp) or ndim(xp) != 1:
if np.shape(xp) != np.shape(fp) or np.ndim(xp) != 1:
raise ValueError("xp and fp must be one-dimensional arrays of equal size")
x_arr, xp_arr = util.promote_dtypes_inexact(x, xp)
fp_arr, = util.promote_dtypes_inexact(fp)
@ -2691,7 +2668,7 @@ def _interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
raise ValueError("jnp.interp: complex x values not supported.")
if period is not None:
if ndim(period) != 0:
if np.ndim(period) != 0:
raise ValueError(f"period must be a scalar; got {period}")
period = ufuncs.abs(period)
x_arr = x_arr % period
@ -3018,7 +2995,7 @@ def bincount(x: ArrayLike, weights: ArrayLike | None = None,
x = lax.convert_element_type(x, 'int32')
if not issubdtype(_dtype(x), np.integer):
raise TypeError(f"x argument to bincount must have an integer type; got {_dtype(x)}")
if ndim(x) != 1:
if np.ndim(x) != 1:
raise ValueError("only 1-dimensional input supported.")
minlength = core.concrete_or_error(operator.index, minlength,
"The error occurred because of argument 'minlength' of jnp.bincount.")
@ -3032,7 +3009,7 @@ def bincount(x: ArrayLike, weights: ArrayLike | None = None,
"The error occurred because of argument 'length' of jnp.bincount.")
if weights is None:
weights = np.array(1, dtype=dtypes.int_)
elif shape(x) != shape(weights):
elif np.shape(x) != np.shape(weights):
raise ValueError("shape of weights must match shape of x.")
return zeros(length, _dtype(weights)).at[clip(x, 0)].add(weights, mode='drop')
@ -3789,7 +3766,7 @@ def nonzero(a: ArrayLike, *, size: int | None = None,
"""
arr = util.ensure_arraylike("nonzero", a)
del a
if ndim(arr) == 0:
if np.ndim(arr) == 0:
raise ValueError("Calling nonzero on 0d arrays is not allowed. "
"Use jnp.atleast_1d(scalar).nonzero() instead.")
mask = arr if arr.dtype == bool else (arr != 0)
@ -3805,7 +3782,7 @@ def nonzero(a: ArrayLike, *, size: int | None = None,
out = tuple((flat_indices // stride) % size for stride, size in zip(strides, arr.shape))
if fill_value is not None:
fill_value_tup = fill_value if isinstance(fill_value, tuple) else arr.ndim * (fill_value,)
if any(_shape(val) != () for val in fill_value_tup):
if any(np.shape(val) != () for val in fill_value_tup):
raise ValueError(f"fill_value must be a scalar or a tuple of length {arr.ndim}; got {fill_value}")
fill_mask = arange(calculated_size) >= mask.sum()
out = tuple(where(fill_mask, fval, entry) for fval, entry in safe_zip(fill_value_tup, out))
@ -3861,7 +3838,7 @@ def flatnonzero(a: ArrayLike, *, size: int | None = None,
@export
@partial(jit, static_argnames=('axis',))
def unwrap(p: ArrayLike, discont: ArrayLike | None = None,
axis: int = -1, period: ArrayLike = 2 * pi) -> Array:
axis: int = -1, period: ArrayLike = 2 * np.pi) -> Array:
"""Unwrap a periodic signal.
JAX implementation of :func:`numpy.unwrap`.
@ -3997,10 +3974,10 @@ def _check_no_padding(axis_padding: tuple[Any, Any], mode: str):
def _pad_constant(array: Array, pad_width: PadValue[int], constant_values: Array) -> Array:
nd = ndim(array)
nd = np.ndim(array)
constant_values = lax_internal._convert_element_type(
constant_values, array.dtype, dtypes.is_weakly_typed(array))
constant_values_nd = ndim(constant_values)
constant_values_nd = np.ndim(constant_values)
if constant_values_nd == 0:
widths = [(low, high, 0) for (low, high) in pad_width]
@ -4033,7 +4010,7 @@ def _pad_constant(array: Array, pad_width: PadValue[int], constant_values: Array
def _pad_wrap(array: Array, pad_width: PadValue[int]) -> Array:
for i in range(ndim(array)):
for i in range(np.ndim(array)):
if array.shape[i] == 0:
_check_no_padding(pad_width[i], "wrap")
continue
@ -4056,7 +4033,7 @@ def _pad_symmetric_or_reflect(array: Array, pad_width: PadValue[int],
assert mode in ("symmetric", "reflect")
assert reflect_type in ("even", "odd")
for i in range(ndim(array)):
for i in range(np.ndim(array)):
if array.shape[i] == 0:
_check_no_padding(pad_width[i], mode)
continue
@ -4121,7 +4098,7 @@ def _pad_symmetric_or_reflect(array: Array, pad_width: PadValue[int],
def _pad_edge(array: Array, pad_width: PadValue[int]) -> Array:
nd = ndim(array)
nd = np.ndim(array)
for i in range(nd):
if array.shape[i] == 0:
_check_no_padding(pad_width[i], "edge")
@ -4142,7 +4119,7 @@ def _pad_edge(array: Array, pad_width: PadValue[int]) -> Array:
def _pad_linear_ramp(array: Array, pad_width: PadValue[int],
end_values: PadValue[ArrayLike]) -> Array:
for axis in range(ndim(array)):
for axis in range(np.ndim(array)):
edge_before = lax.slice_in_dim(array, 0, 1, axis=axis)
edge_after = lax.slice_in_dim(array, -1, None, axis=axis)
ramp_before = linspace(
@ -4176,7 +4153,7 @@ def _pad_linear_ramp(array: Array, pad_width: PadValue[int],
def _pad_stats(array: Array, pad_width: PadValue[int],
stat_length: PadValue[int] | None,
stat_func: PadStatFunc) -> Array:
nd = ndim(array)
nd = np.ndim(array)
for i in range(nd):
if stat_length is None:
stat_before = stat_func(array, axis=i, keepdims=True)
@ -4215,7 +4192,7 @@ def _pad_stats(array: Array, pad_width: PadValue[int],
def _pad_empty(array: Array, pad_width: PadValue[int]) -> Array:
# Note: jax.numpy.empty = jax.numpy.zeros
for i in range(ndim(array)):
for i in range(np.ndim(array)):
shape_before = array.shape[:i] + (pad_width[i][0],) + array.shape[i + 1:]
pad_before = empty_like(array, shape=shape_before)
@ -4226,9 +4203,9 @@ def _pad_empty(array: Array, pad_width: PadValue[int]) -> Array:
def _pad_func(array: Array, pad_width: PadValue[int], func: Callable[..., Any], **kwargs) -> Array:
pad_width = _broadcast_to_pairs(pad_width, ndim(array), "pad_width")
pad_width = _broadcast_to_pairs(pad_width, np.ndim(array), "pad_width")
padded = _pad_constant(array, pad_width, asarray(0))
for axis in range(ndim(padded)):
for axis in range(np.ndim(padded)):
padded = apply_along_axis(func, axis, padded, pad_width[axis], axis, kwargs)
return padded
@ -4238,7 +4215,7 @@ def _pad(array: ArrayLike, pad_width: PadValueLike[int], mode: str,
constant_values: ArrayLike, stat_length: PadValueLike[int],
end_values: PadValueLike[ArrayLike], reflect_type: str):
array = asarray(array)
nd = ndim(array)
nd = np.ndim(array)
if nd == 0:
return array
@ -4406,7 +4383,7 @@ def pad(array: ArrayLike, pad_width: PadValueLike[int | Array | np.ndarray],
"""
util.check_arraylike("pad", array)
pad_width = _broadcast_to_pairs(pad_width, ndim(array), "pad_width")
pad_width = _broadcast_to_pairs(pad_width, np.ndim(array), "pad_width")
if pad_width and not all(core.is_dim(p[0]) and core.is_dim(p[1])
for p in pad_width):
raise TypeError('`pad_width` must be of integral type.')
@ -4501,11 +4478,11 @@ def stack(arrays: np.ndarray | Array | Sequence[ArrayLike],
return concatenate(expand_dims(arrays, axis + 1), axis=axis, dtype=dtype)
else:
util.check_arraylike("stack", *arrays)
shape0 = shape(arrays[0])
shape0 = np.shape(arrays[0])
axis = _canonicalize_axis(axis, len(shape0) + 1)
new_arrays = []
for a in arrays:
if shape(a) != shape0:
if np.shape(a) != shape0:
raise ValueError("All input arrays must have the same shape.")
new_arrays.append(expand_dims(a, axis))
return concatenate(new_arrays, axis=axis, dtype=dtype)
@ -4598,7 +4575,7 @@ def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array:
reps_tup = tuple(reps) # type: ignore[arg-type]
reps_tup = tuple(operator.index(rep) if core.is_constant_dim(rep) else rep
for rep in reps_tup)
A_shape = (1,) * (len(reps_tup) - ndim(A)) + shape(A)
A_shape = (1,) * (len(reps_tup) - np.ndim(A)) + np.shape(A)
reps_tup = (1,) * (len(A_shape) - len(reps_tup)) + reps_tup
result = broadcast_to(reshape(A, [j for i in A_shape for j in [1, i]]),
[k for pair in zip(reps_tup, A_shape) for k in pair])
@ -4667,9 +4644,9 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike],
raise ValueError("Need at least one array to concatenate.")
if axis is None:
return concatenate([ravel(a) for a in arrays], axis=0, dtype=dtype)
if ndim(arrays[0]) == 0:
if np.ndim(arrays[0]) == 0:
raise ValueError("Zero-dimensional arrays cannot be concatenated.")
axis = _canonicalize_axis(axis, ndim(arrays[0]))
axis = _canonicalize_axis(axis, np.ndim(arrays[0]))
if dtype is None:
arrays_out = util.promote_dtypes(*arrays)
else:
@ -5074,7 +5051,7 @@ def choose(a: ArrayLike, choices: Array | np.ndarray | Sequence[ArrayLike],
def _atleast_nd(x: ArrayLike, n: int) -> Array:
m = ndim(x)
m = np.ndim(x)
return lax.broadcast(x, (1,) * (n - m)) if m < n else asarray(x)
def _block(xs: ArrayLike | list[ArrayLike]) -> tuple[Array, int]:
@ -5087,7 +5064,7 @@ def _block(xs: ArrayLike | list[ArrayLike]) -> tuple[Array, int]:
xs_tup, depths = unzip2([_block(x) for x in xs])
if any(d != depths[0] for d in depths[1:]):
raise ValueError("Mismatched list depths in jax.numpy.block")
rank = max(depths[0], max(ndim(x) for x in xs_tup))
rank = max(depths[0], max(np.ndim(x) for x in xs_tup))
xs_tup = tuple(_atleast_nd(x, rank) for x in xs_tup)
return concatenate(xs_tup, axis=-depths[0]), depths[0] + 1
else:
@ -5589,8 +5566,8 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
raise TypeError(f"Unexpected input type for array: {type(object)}")
out_array: Array = lax_internal._convert_element_type(
out, dtype, weak_type=weak_type, sharding=sharding)
if ndmin > ndim(out_array):
out_array = lax.expand_dims(out_array, range(ndmin - ndim(out_array)))
if ndmin > np.ndim(out_array):
out_array = lax.expand_dims(out_array, range(ndmin - np.ndim(out_array)))
return out_array
@ -5839,7 +5816,7 @@ def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array:
Array(True, dtype=bool)
"""
a1, a2 = asarray(a1), asarray(a2)
if shape(a1) != shape(a2):
if np.shape(a1) != np.shape(a2):
return array(False, dtype=bool)
eq = asarray(a1 == a2)
if equal_nan:
@ -6519,7 +6496,7 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
start = start.astype(computation_dtype)
stop = stop.astype(computation_dtype)
bounds_shape = list(lax.broadcast_shapes(shape(start), shape(stop)))
bounds_shape = list(lax.broadcast_shapes(np.shape(start), np.shape(stop)))
broadcast_start = broadcast_to(start, bounds_shape)
broadcast_stop = broadcast_to(stop, bounds_shape)
axis = len(bounds_shape) + axis + 1 if axis < 0 else axis
@ -6542,12 +6519,12 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
_canonicalize_axis(axis, out.ndim))
elif num == 1:
delta = asarray(nan if endpoint else stop - start, dtype=computation_dtype)
delta = asarray(np.nan if endpoint else stop - start, dtype=computation_dtype)
out = reshape(broadcast_start, bounds_shape)
else: # num == 0 degenerate case, match numpy behavior
empty_shape = list(lax.broadcast_shapes(shape(start), shape(stop)))
empty_shape = list(lax.broadcast_shapes(np.shape(start), np.shape(stop)))
empty_shape.insert(axis, 0)
delta = asarray(nan, dtype=computation_dtype)
delta = asarray(np.nan, dtype=computation_dtype)
out = reshape(array([], dtype=dtype), empty_shape)
if issubdtype(dtype, np.integer) and not issubdtype(out.dtype, np.integer):
@ -7053,7 +7030,7 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *,
"value to `total_repeat_length`.")
# Fast path for when repeats is a scalar.
if np.ndim(repeats) == 0 and ndim(arr) != 0:
if np.ndim(repeats) == 0 and np.ndim(arr) != 0:
input_shape = arr.shape
axis = _canonicalize_axis(axis, len(input_shape))
aux_axis = axis + 1
@ -7076,7 +7053,7 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *,
# Special case when a is a scalar.
if arr.ndim == 0:
if shape(repeats) == (1,):
if np.shape(repeats) == (1,):
return full([total_repeat_length], arr)
else:
raise ValueError('`repeat` with a scalar parameter `a` is only '
@ -7279,7 +7256,7 @@ def tril(m: ArrayLike, k: int = 0) -> Array:
[7, 8]]], dtype=int32)
"""
util.check_arraylike("tril", m)
m_shape = shape(m)
m_shape = np.shape(m)
if len(m_shape) < 2:
raise ValueError("Argument to jax.numpy.tril must be at least 2D")
N, M = m_shape[-2:]
@ -7346,7 +7323,7 @@ def triu(m: ArrayLike, k: int = 0) -> Array:
[0, 8]]], dtype=int32)
"""
util.check_arraylike("triu", m)
m_shape = shape(m)
m_shape = np.shape(m)
if len(m_shape) < 2:
raise ValueError("Argument to jax.numpy.triu must be at least 2D")
N, M = m_shape[-2:]
@ -7406,12 +7383,12 @@ def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.trace is not supported.")
if _canonicalize_axis(axis1, ndim(a)) == _canonicalize_axis(axis2, ndim(a)):
if _canonicalize_axis(axis1, np.ndim(a)) == _canonicalize_axis(axis2, np.ndim(a)):
raise ValueError(f"axis1 and axis2 can not be same. axis1={axis1} and axis2={axis2}")
dtypes.check_user_dtype_supported(dtype, "trace")
a_shape = shape(a)
a_shape = np.shape(a)
a = moveaxis(a, (axis1, axis2), (-2, -1))
# Mask out the diagonal and reduce.
@ -7650,7 +7627,7 @@ def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]:
>>> jnp.triu_indices_from(arr, k=-1)
(Array([0, 0, 0, 1, 1, 1, 2, 2], dtype=int32), Array([0, 1, 2, 0, 1, 2, 1, 2], dtype=int32))
"""
arr_shape = shape(arr)
arr_shape = np.shape(arr)
if len(arr_shape) != 2:
raise ValueError("Only 2-D inputs are accepted")
return triu_indices(arr_shape[0], k=k, m=arr_shape[1])
@ -7708,7 +7685,7 @@ def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]:
>>> jnp.tril_indices_from(arr, k=-1)
(Array([1, 2, 2], dtype=int32), Array([0, 0, 1], dtype=int32))
"""
arr_shape = shape(arr)
arr_shape = np.shape(arr)
if len(arr_shape) != 2:
raise ValueError("Only 2-D inputs are accepted")
return tril_indices(arr_shape[0], k=k, m=arr_shape[1])
@ -7863,12 +7840,12 @@ def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]:
Array([0, 1], dtype=int32))
"""
util.check_arraylike("diag_indices_from", arr)
nd = ndim(arr)
if not ndim(arr) >= 2:
nd = np.ndim(arr)
if not np.ndim(arr) >= 2:
raise ValueError("input array must be at least 2-d")
s = shape(arr)
if len(set(shape(arr))) != 1:
s = np.shape(arr)
if len(set(np.shape(arr))) != 1:
raise ValueError("All dimensions of input must be of equal length")
return diag_indices(s[0], ndim=nd)
@ -7913,12 +7890,12 @@ def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0,
"""
util.check_arraylike("diagonal", a)
if ndim(a) < 2:
if np.ndim(a) < 2:
raise ValueError("diagonal requires an array of at least two dimensions.")
offset = core.concrete_or_error(operator.index, offset, "'offset' argument of jnp.diagonal()")
def _default_diag(a):
a_shape = shape(a)
a_shape = np.shape(a)
a = moveaxis(a, (axis1, axis2), (-2, -1))
@ -7932,10 +7909,10 @@ def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0,
# The mosaic lowering rule for diag is only defined for square arrays.
# TODO(mvoz): Add support for offsets.
if shape(a)[0] != shape(a)[1] or ndim(a) != 2 or offset != 0 or _dtype(a) == bool:
if np.shape(a)[0] != np.shape(a)[1] or np.ndim(a) != 2 or offset != 0 or _dtype(a) == bool:
return _default_diag(a)
else:
a_shape_eye = eye(shape(a)[0], dtype=_dtype(a))
a_shape_eye = eye(np.shape(a)[0], dtype=_dtype(a))
def _mosaic_diag(a):
def _sum(x, axis):
@ -8002,7 +7979,7 @@ def diag(v: ArrayLike, k: int = 0) -> Array:
@partial(jit, static_argnames=('k',))
def _diag(v, k):
util.check_arraylike("diag", v)
v_shape = shape(v)
v_shape = np.shape(v)
if len(v_shape) == 1:
zero = lambda x: lax.full_like(x, shape=(), fill_value=0)
n = v_shape[0] + abs(k)
@ -8472,7 +8449,7 @@ def apply_along_axis(
Array([ 65, 133, 243], dtype=int32)
"""
util.check_arraylike("apply_along_axis", arr)
num_dims = ndim(arr)
num_dims = np.ndim(arr)
axis = _canonicalize_axis(axis, num_dims)
func = lambda arr: func1d(arr, *args, **kwargs)
for i in range(1, num_dims - axis):
@ -8675,13 +8652,13 @@ def kron(a: ArrayLike, b: ArrayLike) -> Array:
"""
util.check_arraylike("kron", a, b)
a, b = util.promote_dtypes(a, b)
if ndim(a) < ndim(b):
a = expand_dims(a, range(ndim(b) - ndim(a)))
elif ndim(b) < ndim(a):
b = expand_dims(b, range(ndim(a) - ndim(b)))
a_reshaped = expand_dims(a, range(1, 2 * ndim(a), 2))
b_reshaped = expand_dims(b, range(0, 2 * ndim(b), 2))
out_shape = tuple(np.multiply(shape(a), shape(b)))
if np.ndim(a) < np.ndim(b):
a = expand_dims(a, range(np.ndim(b) - np.ndim(a)))
elif np.ndim(b) < np.ndim(a):
b = expand_dims(b, range(np.ndim(a) - np.ndim(b)))
a_reshaped = expand_dims(a, range(1, 2 * np.ndim(a), 2))
b_reshaped = expand_dims(b, range(0, 2 * np.ndim(b), 2))
out_shape = tuple(np.multiply(np.shape(a), np.shape(b)))
return reshape(lax.mul(a_reshaped, b_reshaped), out_shape)
@ -8809,9 +8786,9 @@ def argwhere(
Array([], shape=(0, 0), dtype=int32)
"""
result = transpose(vstack(nonzero(atleast_1d(a), size=size, fill_value=fill_value)))
if ndim(a) == 0:
if np.ndim(a) == 0:
return result[:0].reshape(result.shape[0], 0)
return result.reshape(result.shape[0], ndim(a))
return result.reshape(result.shape[0], np.ndim(a))
@export
@ -8859,7 +8836,7 @@ def argmax(a: ArrayLike, axis: int | None = None, out: None = None,
@partial(jit, static_argnames=('axis', 'keepdims'), inline=True)
def _argmax(a: Array, axis: int | None = None, keepdims: bool = False) -> Array:
if axis is None:
dims = list(range(ndim(a)))
dims = list(range(np.ndim(a)))
a = ravel(a)
axis = 0
else:
@ -8915,7 +8892,7 @@ def argmin(a: ArrayLike, axis: int | None = None, out: None = None,
@partial(jit, static_argnames=('axis', 'keepdims'), inline=True)
def _argmin(a: Array, axis: int | None = None, keepdims: bool = False) -> Array:
if axis is None:
dims = list(range(ndim(a)))
dims = list(range(np.ndim(a)))
a = ravel(a)
axis = 0
else:
@ -8989,7 +8966,7 @@ def _nanargmax(a, axis: int | None = None, keepdims: bool = False):
if not issubdtype(_dtype(a), np.inexact):
return argmax(a, axis=axis, keepdims=keepdims)
nan_mask = ufuncs.isnan(a)
a = where(nan_mask, -inf, a)
a = where(nan_mask, -np.inf, a)
res = argmax(a, axis=axis, keepdims=keepdims)
return where(reductions.all(nan_mask, axis=axis, keepdims=keepdims), -1, res)
@ -9050,7 +9027,7 @@ def _nanargmin(a, axis: int | None = None, keepdims : bool = False):
if not issubdtype(_dtype(a), np.inexact):
return argmin(a, axis=axis, keepdims=keepdims)
nan_mask = ufuncs.isnan(a)
a = where(nan_mask, inf, a)
a = where(nan_mask, np.inf, a)
res = argmin(a, axis=axis, keepdims=keepdims)
return where(reductions.all(nan_mask, axis=axis, keepdims=keepdims), -1, res)
@ -9191,7 +9168,7 @@ def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array:
"""
util.check_arraylike("rollaxis", a)
start = core.concrete_or_error(operator.index, start, "'start' argument of jnp.rollaxis()")
a_ndim = ndim(a)
a_ndim = np.ndim(a)
axis = _canonicalize_axis(axis, a_ndim)
if not (-a_ndim <= start <= a_ndim):
raise ValueError(f"{start=} must satisfy {-a_ndim}<=start<={a_ndim}")
@ -9764,9 +9741,9 @@ def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True,
w: Array | None = None
if fweights is not None:
fweights = util.ensure_arraylike("cov", fweights)
if ndim(fweights) > 1:
if np.ndim(fweights) > 1:
raise RuntimeError("cannot handle multidimensional fweights")
if shape(fweights)[0] != X.shape[1]:
if np.shape(fweights)[0] != X.shape[1]:
raise RuntimeError("incompatible numbers of samples and fweights")
if not issubdtype(_dtype(fweights), np.integer):
raise TypeError("fweights must be integer.")
@ -9774,9 +9751,9 @@ def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True,
w = abs(fweights)
if aweights is not None:
aweights = util.ensure_arraylike("cov", aweights)
if ndim(aweights) > 1:
if np.ndim(aweights) > 1:
raise RuntimeError("cannot handle multidimensional aweights")
if shape(aweights)[0] != X.shape[1]:
if np.shape(aweights)[0] != X.shape[1]:
raise RuntimeError("incompatible numbers of samples and aweights")
# Ensure positive aweights: note that numpy raises an error for negative aweights.
aweights = abs(aweights)
@ -9877,7 +9854,7 @@ def corrcoef(x: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True) -> A
"""
util.check_arraylike("corrcoef", x)
c = cov(x, y, rowvar)
if len(shape(c)) == 0:
if len(np.shape(c)) == 0:
# scalar - this should yield nan for values (nan/nan, inf/inf, 0/0), 1 otherwise
return ufuncs.divide(c, c)
d = diag(c)
@ -10002,7 +9979,7 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left',
raise ValueError(
f"{method!r} is an invalid value for keyword 'method'. "
"Expected one of ['sort', 'scan', 'scan_unrolled', 'compare_all'].")
if ndim(a) != 1:
if np.ndim(a) != 1:
raise ValueError("a should be 1-dimensional")
a, v = util.promote_dtypes(a, v)
if sorter is not None:

View File

@ -478,7 +478,7 @@ def _slogdet_lu(a: Array) -> tuple[Array, Array]:
jnp.array(0, dtype=dtype),
sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype))
logdet = jnp.where(
is_zero, jnp.array(-jnp.inf, dtype=dtype),
is_zero, jnp.array(-np.inf, dtype=dtype),
reductions.sum(ufuncs.log(ufuncs.abs(diag)).astype(dtype), axis=-1))
return sign, ufuncs.real(logdet)
@ -539,7 +539,7 @@ def slogdet(a: ArrayLike, *, method: str | None = None) -> SlogdetResult:
"""
a = ensure_arraylike("jnp.linalg.slogdet", a)
a, = promote_dtypes_inexact(a)
a_shape = jnp.shape(a)
a_shape = np.shape(a)
if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
raise ValueError(f"Argument to slogdet() must have shape [..., n, n], got {a_shape}")
if method is None or method == "lu":
@ -610,8 +610,8 @@ def _cofactor_solve(a: ArrayLike, b: ArrayLike) -> tuple[Array, Array]:
a, b = ensure_arraylike("jnp.linalg._cofactor_solve", a, b)
a, = promote_dtypes_inexact(a)
b, = promote_dtypes_inexact(b)
a_shape = jnp.shape(a)
b_shape = jnp.shape(b)
a_shape = np.shape(a)
b_shape = np.shape(b)
a_ndims = len(a_shape)
if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
and b_shape[-2:] == a_shape[-2:]):
@ -710,7 +710,7 @@ def det(a: ArrayLike) -> Array:
"""
a = ensure_arraylike("jnp.linalg.det", a)
a, = promote_dtypes_inexact(a)
a_shape = jnp.shape(a)
a_shape = np.shape(a)
if len(a_shape) >= 2 and a_shape[-1] == 2 and a_shape[-2] == 2:
return _det_2x2(a)
elif len(a_shape) >= 2 and a_shape[-1] == 3 and a_shape[-2] == 3:
@ -976,10 +976,10 @@ def _pinv(a: ArrayLike, rtol: ArrayLike | None = None, hermitian: bool = False)
u, s, vh = svd(arr, full_matrices=False, hermitian=hermitian)
# Singular values less than or equal to ``rtol * largest_singular_value``
# are set to zero.
rtol = lax.expand_dims(rtol[..., jnp.newaxis], range(s.ndim - rtol.ndim - 1))
rtol = lax.expand_dims(rtol[..., np.newaxis], range(s.ndim - rtol.ndim - 1))
cutoff = rtol * s[..., 0:1]
s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype)
res = tensor_contractions.matmul(vh.mT, ufuncs.divide(u.mT, s[..., jnp.newaxis]),
s = jnp.where(s > cutoff, s, np.inf).astype(u.dtype)
res = tensor_contractions.matmul(vh.mT, ufuncs.divide(u.mT, s[..., np.newaxis]),
precision=lax.Precision.HIGHEST)
return lax.convert_element_type(res, arr.dtype)
@ -1148,7 +1148,7 @@ def norm(x: ArrayLike, ord: int | str | None = None,
"""
x = ensure_arraylike("jnp.linalg.norm", x)
x, = promote_dtypes_inexact(x)
x_shape = jnp.shape(x)
x_shape = np.shape(x)
ndim = len(x_shape)
if axis is None:
@ -1181,12 +1181,12 @@ def norm(x: ArrayLike, ord: int | str | None = None,
col_axis -= 1
return reductions.amin(reductions.sum(ufuncs.abs(x), axis=row_axis, keepdims=keepdims),
axis=col_axis, keepdims=keepdims)
elif ord == jnp.inf:
elif ord == np.inf:
if not keepdims and row_axis > col_axis:
row_axis -= 1
return reductions.amax(reductions.sum(ufuncs.abs(x), axis=col_axis, keepdims=keepdims),
axis=row_axis, keepdims=keepdims)
elif ord == -jnp.inf:
elif ord == -np.inf:
if not keepdims and row_axis > col_axis:
row_axis -= 1
return reductions.amin(reductions.sum(ufuncs.abs(x), axis=col_axis, keepdims=keepdims),
@ -1392,7 +1392,7 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None, *,
mask = s >= jnp.array(rcond, dtype=s.dtype) * s[0]
rank = mask.sum()
safe_s = jnp.where(mask, s, 1).astype(a.dtype)
s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis]
s_inv = jnp.where(mask, 1 / safe_s, 0)[:, np.newaxis]
uTb = tensor_contractions.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST)
x = tensor_contractions.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST)
# Numpy returns empty residuals in some cases. To allow compilation, we
@ -1651,9 +1651,9 @@ def vector_norm(x: ArrayLike, /, *, axis: int | tuple[int, ...] | None = None, k
if ord is None or ord == 2:
return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis,
keepdims=keepdims))
elif ord == jnp.inf:
elif ord == np.inf:
return reductions.amax(ufuncs.abs(x), axis=axis, keepdims=keepdims)
elif ord == -jnp.inf:
elif ord == -np.inf:
return reductions.amin(ufuncs.abs(x), axis=axis, keepdims=keepdims)
elif ord == 0:
return reductions.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype,
@ -2177,7 +2177,7 @@ def cond(x: ArrayLike, p=None):
raise ValueError(f"jnp.linalg.cond: for {p=}, array must be square; got {arr.shape=}")
r = norm(x, ord=p, axis=(-2, -1)) * norm(inv(x), ord=p, axis=(-2, -1))
# Convert NaNs to infs where original array has no NaNs.
return jnp.where(ufuncs.isnan(r) & ~ufuncs.isnan(x).any(axis=(-2, -1)), jnp.inf, r)
return jnp.where(ufuncs.isnan(r) & ~ufuncs.isnan(x).any(axis=(-2, -1)), np.inf, r)
@export

View File

@ -19,6 +19,8 @@ import re
from typing import Any
import warnings
import numpy as np
from jax._src import api
from jax._src import config
from jax import lax
@ -140,7 +142,7 @@ def _check_output_dims(
"""Check that output core dimensions match the signature."""
def wrapped(*args):
out = func(*args)
out_shapes = map(jnp.shape, out if isinstance(out, tuple) else [out])
out_shapes = map(np.shape, out if isinstance(out, tuple) else [out])
if expected_output_core_dims is None:
output_core_dims = [()] * len(out_shapes)

View File

@ -98,7 +98,7 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
FutureWarning)
idx = indexing.merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
indexer = indexing.index_to_gather(jnp.shape(x), idx,
indexer = indexing.index_to_gather(np.shape(x), idx,
normalize_indices=normalize_indices)
# Avoid calling scatter if the slice shape is empty, both as a fast path and
@ -110,7 +110,7 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
# Broadcast `y` to the slice output shape.
y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
# Collapse any `None`/`jnp.newaxis` dimensions.
# Collapse any `None`/`np.newaxis` dimensions.
y = jnp.squeeze(y, axis=indexer.newaxis_dims)
if indexer.reversed_y_dims:
y = lax.rev(y, indexer.reversed_y_dims)

View File

@ -39,9 +39,7 @@ from jax._src.numpy.lax_numpy import (
array as array,
array_equal as array_equal,
array_equiv as array_equiv,
array_repr as array_repr,
array_split as array_split,
array_str as array_str,
astype as astype,
asarray as asarray,
atleast_1d as atleast_1d,
@ -75,10 +73,7 @@ from jax._src.numpy.lax_numpy import (
digitize as digitize,
dsplit as dsplit,
dstack as dstack,
dtype as dtype,
e as e,
ediff1d as ediff1d,
euler_gamma as euler_gamma,
expand_dims as expand_dims,
extract as extract,
eye as eye,
@ -111,7 +106,6 @@ from jax._src.numpy.lax_numpy import (
identity as identity,
iinfo as iinfo,
indices as indices,
inf as inf,
insert as insert,
interp as interp,
isclose as isclose,
@ -121,7 +115,6 @@ from jax._src.numpy.lax_numpy import (
isrealobj as isrealobj,
isscalar as isscalar,
issubdtype as issubdtype,
iterable as iterable,
ix_ as ix_,
kron as kron,
lcm as lcm,
@ -132,17 +125,13 @@ from jax._src.numpy.lax_numpy import (
matrix_transpose as matrix_transpose,
meshgrid as meshgrid,
moveaxis as moveaxis,
nan as nan,
nan_to_num as nan_to_num,
nanargmax as nanargmax,
nanargmin as nanargmin,
ndim as ndim,
newaxis as newaxis,
nonzero as nonzero,
packbits as packbits,
pad as pad,
permute_dims as permute_dims,
pi as pi,
piecewise as piecewise,
printoptions as printoptions,
promote_types as promote_types,
@ -156,13 +145,9 @@ from jax._src.numpy.lax_numpy import (
rollaxis as rollaxis,
rot90 as rot90,
round as round,
save as save,
savez as savez,
searchsorted as searchsorted,
select as select,
set_printoptions as set_printoptions,
shape as shape,
size as size,
split as split,
squeeze as squeeze,
stack as stack,
@ -277,17 +262,32 @@ from jax._src.numpy.window_functions import (
kaiser as kaiser,
)
# NumPy generic scalar types:
# Some APIs come directly from NumPy:
from numpy import (
array_repr as array_repr,
array_str as array_str,
character as character,
complexfloating as complexfloating,
dtype as dtype,
e as e,
euler_gamma as euler_gamma,
flexible as flexible,
floating as floating,
generic as generic,
inexact as inexact,
inf as inf,
integer as integer,
iterable as iterable,
nan as nan,
ndim as ndim,
newaxis as newaxis,
number as number,
object_ as object_,
pi as pi,
save as save,
savez as savez,
shape as shape,
size as size,
signedinteger as signedinteger,
unsignedinteger as unsignedinteger,
)