Mirror of https://github.com/ROCm/jax.git (synced 2025-04-17 20:36:05 +00:00)

refactor: import numpy objects directly in jax.numpy

This commit is contained in:
parent 36d7f8530b
commit 33b989ac9e
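The hunks below replace jnp-level aliases (ndim, shape, size, pi, e, inf, nan, newaxis, ...) with the NumPy objects they wrap. A minimal sketch of why the swap is behavior-preserving (my illustration, not part of the commit): np.ndim, np.shape, and np.size read the corresponding attribute of whatever array-like they are given, so they return the same static metadata for JAX arrays and tracers, and the constants are ordinary Python floats shared by both namespaces.

import numpy as np
import jax
import jax.numpy as jnp

x = jnp.ones((2, 3))
assert np.ndim(x) == x.ndim == 2          # np.ndim reads the .ndim attribute
assert np.shape(x) == x.shape == (2, 3)   # np.shape reads the .shape attribute
assert np.size(x) == x.size == 6
assert isinstance(np.pi, float) and isinstance(np.nan, float)

@jax.jit
def f(y):
  # Tracers also expose .ndim/.shape, so this is resolved statically at trace time.
  return y.sum() if np.ndim(y) == 2 else y

print(f(x))  # 6.0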
@@ -61,7 +61,7 @@ def _mask(x, dims, alternative=0):
Replaces values outside those dimensions with `alternative`. `alternative` is
broadcast with `x`.
"""
- assert jnp.ndim(x) == len(dims)
+ assert np.ndim(x) == len(dims)
mask = None
for i, d in enumerate(dims):
if d is not None:
@@ -145,7 +145,7 @@ def _projector_subspace(P, H, n, rank, maxiter=2, swap=False):
N, _ = P.shape
negative_column_norms = -jnp_linalg.norm(P, axis=1)
# `jnp.argsort` ensures NaNs sort last, so set masked-out column norms to NaN.
- negative_column_norms = _mask(negative_column_norms, (n,), jnp.nan)
+ negative_column_norms = _mask(negative_column_norms, (n,), np.nan)
sort_idxs = jnp.argsort(negative_column_norms)
X = P[:, sort_idxs]
# X = X[:, :rank]
@@ -397,7 +397,7 @@ def _eigh_work(H, n, termination_size, subset_by_index):
def default_case(agenda, blocks, eigenvectors):
V = _slice(eigenvectors, (0, offset), (n, b), (N, B))
# TODO: Improve this?
- split_point = reductions.nanmedian(_mask(jnp.diag(ufuncs.real(H)), (b,), jnp.nan))
+ split_point = reductions.nanmedian(_mask(jnp.diag(ufuncs.real(H)), (b,), np.nan))
H_minus, V_minus, H_plus, V_plus, rank = split_spectrum(
H, b, split_point, V0=V)

@@ -564,7 +564,7 @@ def eigh(
eig_vals, eig_vecs = _eigh_work(
H, n, termination_size=termination_size, subset_by_index=subset_by_index
)
- eig_vals = _mask(ufuncs.real(eig_vals), (n,), jnp.nan)
+ eig_vals = _mask(ufuncs.real(eig_vals), (n,), np.nan)
if sort_eigenvalues or compute_slice:
sort_idxs = jnp.argsort(eig_vals)
if compute_slice:

@@ -82,18 +82,8 @@ for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib']:
else:
break

- newaxis = None
T = TypeVar('T')


- # NumPy constants

- pi = np.pi
- e = np.e
- euler_gamma = np.euler_gamma
- inf = np.inf
- nan = np.nan

# Wrappers for NumPy printoptions

def get_printoptions():
@@ -169,9 +159,6 @@ def iscomplexobj(x: Any) -> bool:
typ = asarray(x).dtype.type
return issubdtype(typ, np.complexfloating)

- shape = _shape = np.shape
- ndim = _ndim = np.ndim
- size = np.size

def _dtype(x: Any) -> DType:
return dtypes.dtype(x, canonicalize=True)
@@ -180,19 +167,11 @@ def _dtype(x: Any) -> DType:
iinfo = dtypes.iinfo
finfo = dtypes.finfo

- dtype = np.dtype
can_cast = dtypes.can_cast
promote_types = dtypes.promote_types

ComplexWarning = NumpyComplexWarning

- # Numpy functions
- array_str = np.array_str
- array_repr = np.array_repr
-
- save = np.save
- savez = np.savez
-
_lax_const = lax_internal._const


@@ -534,8 +513,6 @@ def isscalar(element: Any) -> bool:
return asarray(element).ndim == 0
return False

- iterable = np.iterable
-
@export
def result_type(*args: Any) -> DType:
@@ -621,7 +598,7 @@ def trunc(x: ArrayLike) -> Array:
@partial(jit, static_argnames=['mode', 'op', 'precision', 'preferred_element_type'])
def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike,
preferred_element_type: DTypeLike | None = None) -> Array:
- if ndim(x) != 1 or ndim(y) != 1:
+ if np.ndim(x) != 1 or np.ndim(y) != 1:
raise ValueError(f"{op}() only support 1-dimensional inputs.")
if preferred_element_type is None:
# if unspecified, promote to inexact following NumPy's default for convolutions.
@@ -856,7 +833,7 @@ def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10,
util.check_arraylike("histogram_bin_edges", a, bins)
arr = asarray(a)
dtype = dtypes.to_inexact_dtype(arr.dtype)
- if _ndim(bins) == 1:
+ if np.ndim(bins) == 1:
return asarray(bins, dtype=dtype)

bins_int = core.concrete_or_error(operator.index, bins,
@@ -864,7 +841,7 @@ def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10,
if range is None:
range = [arr.min(), arr.max()]
range = asarray(range, dtype=dtype)
- if shape(range) != (2,):
+ if np.shape(range) != (2,):
raise ValueError(f"`range` must be either None or a sequence of scalars, got {range}")
range = (where(reductions.ptp(range) == 0, range[0] - 0.5, range[0]),
where(reductions.ptp(range) == 0, range[1] + 0.5, range[1]))
@@ -940,7 +917,7 @@ def histogram(a: ArrayLike, bins: ArrayLike = 10,
weights = ones_like(a)
else:
util.check_arraylike("histogram", a, bins, weights)
- if shape(a) != shape(weights):
+ if np.shape(a) != np.shape(weights):
raise ValueError("weights should have the same shape as a.")
a, weights = util.promote_dtypes_inexact(a, weights)

@@ -1105,13 +1082,13 @@ def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10,
sample, = util.promote_dtypes_inexact(sample)
else:
util.check_arraylike("histogramdd", sample, weights)
- if shape(weights) != shape(sample)[:1]:
+ if np.shape(weights) != np.shape(sample)[:1]:
raise ValueError("should have one weight for each sample.")
sample, weights = util.promote_dtypes_inexact(sample, weights)
- N, D = shape(sample)
+ N, D = np.shape(sample)

if range is not None and (
- len(range) != D or any(r is not None and shape(r)[0] != 2 for r in range)): # type: ignore[arg-type]
+ len(range) != D or any(r is not None and np.shape(r)[0] != 2 for r in range)): # type: ignore[arg-type]
raise ValueError(f"For sample.shape={(N, D)}, range must be a sequence "
f"of {D} pairs or Nones; got {range=}")

@@ -1228,8 +1205,8 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array:
[2, 4]], dtype=int32)
"""
util.check_arraylike("transpose", a)
- axes_ = list(range(ndim(a))[::-1]) if axes is None else axes
- axes_ = [_canonicalize_axis(i, ndim(a)) for i in axes_]
+ axes_ = list(range(np.ndim(a))[::-1]) if axes is None else axes
+ axes_ = [_canonicalize_axis(i, np.ndim(a)) for i in axes_]
return lax.transpose(a, axes_)


@@ -1383,8 +1360,8 @@ def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array:
f"two, but got first argument of shape {np.shape(m)}, "
f"which has ndim {np.ndim(m)}")
ax1, ax2 = axes
- ax1 = _canonicalize_axis(ax1, ndim(m))
- ax2 = _canonicalize_axis(ax2, ndim(m))
+ ax1 = _canonicalize_axis(ax1, np.ndim(m))
+ ax2 = _canonicalize_axis(ax2, np.ndim(m))
if ax1 == ax2:
raise ValueError("Axes must be different") # same as numpy error
k = k % 4
@@ -1393,7 +1370,7 @@ def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array:
elif k == 2:
return flip(flip(m, ax1), ax2)
else:
- perm = list(range(ndim(m)))
+ perm = list(range(np.ndim(m)))
perm[ax1], perm[ax2] = perm[ax2], perm[ax1]
if k == 1:
return transpose(flip(m, ax2), perm)
@@ -1464,9 +1441,9 @@ def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array:
@partial(jit, static_argnames=('axis',))
def _flip(m: Array, axis: int | tuple[int, ...] | None = None) -> Array:
if axis is None:
- return lax.rev(m, list(range(len(shape(m)))))
+ return lax.rev(m, list(range(len(np.shape(m)))))
axis = _ensure_index_tuple(axis)
- return lax.rev(m, [_canonicalize_axis(ax, ndim(m)) for ax in axis])
+ return lax.rev(m, [_canonicalize_axis(ax, np.ndim(m)) for ax in axis])


@export
@@ -1617,7 +1594,7 @@ def angle(z: ArrayLike, deg: bool = False) -> Array:
im = ufuncs.imag(z)
dtype = _dtype(re)
if not issubdtype(dtype, np.inexact) or (
- issubdtype(_dtype(z), np.floating) and ndim(z) == 0):
+ issubdtype(_dtype(z), np.floating) and np.ndim(z) == 0):
dtype = dtypes.canonicalize_dtype(dtypes.float_)
re = lax.convert_element_type(re, dtype)
im = lax.convert_element_type(im, dtype)
@@ -1704,7 +1681,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1,
combined: list[Array] = []
if prepend is not None:
prepend = util.ensure_arraylike("diff", prepend)
- if not ndim(prepend):
+ if not np.ndim(prepend):
shape = list(arr.shape)
shape[axis] = 1
prepend = broadcast_to(prepend, tuple(shape))
@@ -1714,7 +1691,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1,

if append is not None:
append = util.ensure_arraylike("diff", append)
- if not ndim(append):
+ if not np.ndim(append):
shape = list(arr.shape)
shape[axis] = 1
append = broadcast_to(append, tuple(shape))
@@ -1878,12 +1855,12 @@ def gradient(
upper_edge = sliced(1, 2) - sliced(0, 1)
lower_edge = sliced(-1, None) - sliced(-2, -1)

- if ndim(h) == 0:
+ if np.ndim(h) == 0:
inner = (sliced(2, None) - sliced(None, -2)) * 0.5 / h
lower_edge /= h
upper_edge /= h

- elif ndim(h) == 1:
+ elif np.ndim(h) == 1:
if len(h) != a.shape[axis]:
raise ValueError(
"Spacing arrays must have the same length as the "
@@ -2112,7 +2089,7 @@ def ravel(a: ArrayLike, order: str = "C") -> Array:
util.check_arraylike("ravel", a)
if order == "K":
raise NotImplementedError("Ravel not implemented for order='K'.")
- return reshape(a, (size(a),), order)
+ return reshape(a, (np.size(a),), order)


@export
@@ -2259,7 +2236,7 @@ def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]:
# TODO: Consider warning here since shape is supposed to be a sequence, so
# this should not happen.
shape = [shape]
- if any(ndim(s) != 0 for s in shape):
+ if any(np.ndim(s) != 0 for s in shape):
raise ValueError("unravel_index: shape should be a scalar or 1D sequence.")
out_indices: list[ArrayLike] = [0] * len(shape)
for i, s in reversed(list(enumerate(shape))):
@@ -2385,7 +2362,7 @@ def squeeze(a: ArrayLike, axis: int | Sequence[int] | None = None) -> Array:
@partial(jit, static_argnames=('axis',), inline=True)
def _squeeze(a: Array, axis: tuple[int, ...]) -> Array:
if axis is None:
- a_shape = shape(a)
+ a_shape = np.shape(a)
if not core.is_constant_shape(a_shape):
# We do not even know the rank of the output if the input shape is not known
raise ValueError("jnp.squeeze with axis=None is not supported with shape polymorphism")
@@ -2507,7 +2484,7 @@ def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array:
(2, 5, 4, 3)
"""
util.check_arraylike("swapaxes", a)
- perm = np.arange(ndim(a))
+ perm = np.arange(np.ndim(a))
perm[axis1], perm[axis2] = perm[axis2], perm[axis1]
return lax.transpose(a, list(perm))

@@ -2567,12 +2544,12 @@ def moveaxis(a: ArrayLike, source: int | Sequence[int],

@partial(jit, static_argnames=('source', 'destination'), inline=True)
def _moveaxis(a: Array, source: tuple[int, ...], destination: tuple[int, ...]) -> Array:
- source = tuple(_canonicalize_axis(i, ndim(a)) for i in source)
- destination = tuple(_canonicalize_axis(i, ndim(a)) for i in destination)
+ source = tuple(_canonicalize_axis(i, np.ndim(a)) for i in source)
+ destination = tuple(_canonicalize_axis(i, np.ndim(a)) for i in destination)
if len(source) != len(destination):
raise ValueError("Inconsistent number of elements: {} vs {}"
.format(len(source), len(destination)))
- perm = [i for i in range(ndim(a)) if i not in source]
+ perm = [i for i in range(np.ndim(a)) if i not in source]
for dest, src in sorted(zip(destination, source)):
perm.insert(dest, src)
return lax.transpose(a, perm)
@@ -2666,7 +2643,7 @@ def _interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
right: ArrayLike | str | None = None,
period: ArrayLike | None = None) -> Array:
util.check_arraylike("interp", x, xp, fp)
- if shape(xp) != shape(fp) or ndim(xp) != 1:
+ if np.shape(xp) != np.shape(fp) or np.ndim(xp) != 1:
raise ValueError("xp and fp must be one-dimensional arrays of equal size")
x_arr, xp_arr = util.promote_dtypes_inexact(x, xp)
fp_arr, = util.promote_dtypes_inexact(fp)
@@ -2691,7 +2668,7 @@ def _interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
raise ValueError("jnp.interp: complex x values not supported.")

if period is not None:
- if ndim(period) != 0:
+ if np.ndim(period) != 0:
raise ValueError(f"period must be a scalar; got {period}")
period = ufuncs.abs(period)
x_arr = x_arr % period
@@ -3018,7 +2995,7 @@ def bincount(x: ArrayLike, weights: ArrayLike | None = None,
x = lax.convert_element_type(x, 'int32')
if not issubdtype(_dtype(x), np.integer):
raise TypeError(f"x argument to bincount must have an integer type; got {_dtype(x)}")
- if ndim(x) != 1:
+ if np.ndim(x) != 1:
raise ValueError("only 1-dimensional input supported.")
minlength = core.concrete_or_error(operator.index, minlength,
"The error occurred because of argument 'minlength' of jnp.bincount.")
@@ -3032,7 +3009,7 @@ def bincount(x: ArrayLike, weights: ArrayLike | None = None,
"The error occurred because of argument 'length' of jnp.bincount.")
if weights is None:
weights = np.array(1, dtype=dtypes.int_)
- elif shape(x) != shape(weights):
+ elif np.shape(x) != np.shape(weights):
raise ValueError("shape of weights must match shape of x.")
return zeros(length, _dtype(weights)).at[clip(x, 0)].add(weights, mode='drop')

@@ -3789,7 +3766,7 @@ def nonzero(a: ArrayLike, *, size: int | None = None,
"""
arr = util.ensure_arraylike("nonzero", a)
del a
- if ndim(arr) == 0:
+ if np.ndim(arr) == 0:
raise ValueError("Calling nonzero on 0d arrays is not allowed. "
"Use jnp.atleast_1d(scalar).nonzero() instead.")
mask = arr if arr.dtype == bool else (arr != 0)
@@ -3805,7 +3782,7 @@ def nonzero(a: ArrayLike, *, size: int | None = None,
out = tuple((flat_indices // stride) % size for stride, size in zip(strides, arr.shape))
if fill_value is not None:
fill_value_tup = fill_value if isinstance(fill_value, tuple) else arr.ndim * (fill_value,)
- if any(_shape(val) != () for val in fill_value_tup):
+ if any(np.shape(val) != () for val in fill_value_tup):
raise ValueError(f"fill_value must be a scalar or a tuple of length {arr.ndim}; got {fill_value}")
fill_mask = arange(calculated_size) >= mask.sum()
out = tuple(where(fill_mask, fval, entry) for fval, entry in safe_zip(fill_value_tup, out))
@@ -3861,7 +3838,7 @@ def flatnonzero(a: ArrayLike, *, size: int | None = None,
@export
@partial(jit, static_argnames=('axis',))
def unwrap(p: ArrayLike, discont: ArrayLike | None = None,
- axis: int = -1, period: ArrayLike = 2 * pi) -> Array:
+ axis: int = -1, period: ArrayLike = 2 * np.pi) -> Array:
"""Unwrap a periodic signal.

JAX implementation of :func:`numpy.unwrap`.
@@ -3997,10 +3974,10 @@ def _check_no_padding(axis_padding: tuple[Any, Any], mode: str):


def _pad_constant(array: Array, pad_width: PadValue[int], constant_values: Array) -> Array:
- nd = ndim(array)
+ nd = np.ndim(array)
constant_values = lax_internal._convert_element_type(
constant_values, array.dtype, dtypes.is_weakly_typed(array))
- constant_values_nd = ndim(constant_values)
+ constant_values_nd = np.ndim(constant_values)

if constant_values_nd == 0:
widths = [(low, high, 0) for (low, high) in pad_width]
@@ -4033,7 +4010,7 @@ def _pad_constant(array: Array, pad_width: PadValue[int], constant_values: Array


def _pad_wrap(array: Array, pad_width: PadValue[int]) -> Array:
- for i in range(ndim(array)):
+ for i in range(np.ndim(array)):
if array.shape[i] == 0:
_check_no_padding(pad_width[i], "wrap")
continue
@@ -4056,7 +4033,7 @@ def _pad_symmetric_or_reflect(array: Array, pad_width: PadValue[int],
assert mode in ("symmetric", "reflect")
assert reflect_type in ("even", "odd")

- for i in range(ndim(array)):
+ for i in range(np.ndim(array)):
if array.shape[i] == 0:
_check_no_padding(pad_width[i], mode)
continue
@@ -4121,7 +4098,7 @@ def _pad_symmetric_or_reflect(array: Array, pad_width: PadValue[int],


def _pad_edge(array: Array, pad_width: PadValue[int]) -> Array:
- nd = ndim(array)
+ nd = np.ndim(array)
for i in range(nd):
if array.shape[i] == 0:
_check_no_padding(pad_width[i], "edge")
@@ -4142,7 +4119,7 @@ def _pad_edge(array: Array, pad_width: PadValue[int]) -> Array:

def _pad_linear_ramp(array: Array, pad_width: PadValue[int],
end_values: PadValue[ArrayLike]) -> Array:
- for axis in range(ndim(array)):
+ for axis in range(np.ndim(array)):
edge_before = lax.slice_in_dim(array, 0, 1, axis=axis)
edge_after = lax.slice_in_dim(array, -1, None, axis=axis)
ramp_before = linspace(
@@ -4176,7 +4153,7 @@ def _pad_linear_ramp(array: Array, pad_width: PadValue[int],
def _pad_stats(array: Array, pad_width: PadValue[int],
stat_length: PadValue[int] | None,
stat_func: PadStatFunc) -> Array:
- nd = ndim(array)
+ nd = np.ndim(array)
for i in range(nd):
if stat_length is None:
stat_before = stat_func(array, axis=i, keepdims=True)
@@ -4215,7 +4192,7 @@ def _pad_stats(array: Array, pad_width: PadValue[int],

def _pad_empty(array: Array, pad_width: PadValue[int]) -> Array:
# Note: jax.numpy.empty = jax.numpy.zeros
- for i in range(ndim(array)):
+ for i in range(np.ndim(array)):
shape_before = array.shape[:i] + (pad_width[i][0],) + array.shape[i + 1:]
pad_before = empty_like(array, shape=shape_before)

@@ -4226,9 +4203,9 @@ def _pad_empty(array: Array, pad_width: PadValue[int]) -> Array:


def _pad_func(array: Array, pad_width: PadValue[int], func: Callable[..., Any], **kwargs) -> Array:
- pad_width = _broadcast_to_pairs(pad_width, ndim(array), "pad_width")
+ pad_width = _broadcast_to_pairs(pad_width, np.ndim(array), "pad_width")
padded = _pad_constant(array, pad_width, asarray(0))
- for axis in range(ndim(padded)):
+ for axis in range(np.ndim(padded)):
padded = apply_along_axis(func, axis, padded, pad_width[axis], axis, kwargs)
return padded

@@ -4238,7 +4215,7 @@ def _pad(array: ArrayLike, pad_width: PadValueLike[int], mode: str,
constant_values: ArrayLike, stat_length: PadValueLike[int],
end_values: PadValueLike[ArrayLike], reflect_type: str):
array = asarray(array)
- nd = ndim(array)
+ nd = np.ndim(array)

if nd == 0:
return array
@@ -4406,7 +4383,7 @@ def pad(array: ArrayLike, pad_width: PadValueLike[int | Array | np.ndarray],
"""

util.check_arraylike("pad", array)
- pad_width = _broadcast_to_pairs(pad_width, ndim(array), "pad_width")
+ pad_width = _broadcast_to_pairs(pad_width, np.ndim(array), "pad_width")
if pad_width and not all(core.is_dim(p[0]) and core.is_dim(p[1])
for p in pad_width):
raise TypeError('`pad_width` must be of integral type.')
@@ -4501,11 +4478,11 @@ def stack(arrays: np.ndarray | Array | Sequence[ArrayLike],
return concatenate(expand_dims(arrays, axis + 1), axis=axis, dtype=dtype)
else:
util.check_arraylike("stack", *arrays)
- shape0 = shape(arrays[0])
+ shape0 = np.shape(arrays[0])
axis = _canonicalize_axis(axis, len(shape0) + 1)
new_arrays = []
for a in arrays:
- if shape(a) != shape0:
+ if np.shape(a) != shape0:
raise ValueError("All input arrays must have the same shape.")
new_arrays.append(expand_dims(a, axis))
return concatenate(new_arrays, axis=axis, dtype=dtype)
@@ -4598,7 +4575,7 @@ def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array:
reps_tup = tuple(reps) # type: ignore[arg-type]
reps_tup = tuple(operator.index(rep) if core.is_constant_dim(rep) else rep
for rep in reps_tup)
- A_shape = (1,) * (len(reps_tup) - ndim(A)) + shape(A)
+ A_shape = (1,) * (len(reps_tup) - np.ndim(A)) + np.shape(A)
reps_tup = (1,) * (len(A_shape) - len(reps_tup)) + reps_tup
result = broadcast_to(reshape(A, [j for i in A_shape for j in [1, i]]),
[k for pair in zip(reps_tup, A_shape) for k in pair])
@@ -4667,9 +4644,9 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike],
raise ValueError("Need at least one array to concatenate.")
if axis is None:
return concatenate([ravel(a) for a in arrays], axis=0, dtype=dtype)
- if ndim(arrays[0]) == 0:
+ if np.ndim(arrays[0]) == 0:
raise ValueError("Zero-dimensional arrays cannot be concatenated.")
- axis = _canonicalize_axis(axis, ndim(arrays[0]))
+ axis = _canonicalize_axis(axis, np.ndim(arrays[0]))
if dtype is None:
arrays_out = util.promote_dtypes(*arrays)
else:
@@ -5074,7 +5051,7 @@ def choose(a: ArrayLike, choices: Array | np.ndarray | Sequence[ArrayLike],


def _atleast_nd(x: ArrayLike, n: int) -> Array:
- m = ndim(x)
+ m = np.ndim(x)
return lax.broadcast(x, (1,) * (n - m)) if m < n else asarray(x)

def _block(xs: ArrayLike | list[ArrayLike]) -> tuple[Array, int]:
@@ -5087,7 +5064,7 @@ def _block(xs: ArrayLike | list[ArrayLike]) -> tuple[Array, int]:
xs_tup, depths = unzip2([_block(x) for x in xs])
if any(d != depths[0] for d in depths[1:]):
raise ValueError("Mismatched list depths in jax.numpy.block")
- rank = max(depths[0], max(ndim(x) for x in xs_tup))
+ rank = max(depths[0], max(np.ndim(x) for x in xs_tup))
xs_tup = tuple(_atleast_nd(x, rank) for x in xs_tup)
return concatenate(xs_tup, axis=-depths[0]), depths[0] + 1
else:
@@ -5589,8 +5566,8 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
raise TypeError(f"Unexpected input type for array: {type(object)}")
out_array: Array = lax_internal._convert_element_type(
out, dtype, weak_type=weak_type, sharding=sharding)
- if ndmin > ndim(out_array):
- out_array = lax.expand_dims(out_array, range(ndmin - ndim(out_array)))
+ if ndmin > np.ndim(out_array):
+ out_array = lax.expand_dims(out_array, range(ndmin - np.ndim(out_array)))
return out_array

@@ -5839,7 +5816,7 @@ def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array:
Array(True, dtype=bool)
"""
a1, a2 = asarray(a1), asarray(a2)
- if shape(a1) != shape(a2):
+ if np.shape(a1) != np.shape(a2):
return array(False, dtype=bool)
eq = asarray(a1 == a2)
if equal_nan:
@@ -6519,7 +6496,7 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
start = start.astype(computation_dtype)
stop = stop.astype(computation_dtype)

- bounds_shape = list(lax.broadcast_shapes(shape(start), shape(stop)))
+ bounds_shape = list(lax.broadcast_shapes(np.shape(start), np.shape(stop)))
broadcast_start = broadcast_to(start, bounds_shape)
broadcast_stop = broadcast_to(stop, bounds_shape)
axis = len(bounds_shape) + axis + 1 if axis < 0 else axis
@@ -6542,12 +6519,12 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
_canonicalize_axis(axis, out.ndim))

elif num == 1:
- delta = asarray(nan if endpoint else stop - start, dtype=computation_dtype)
+ delta = asarray(np.nan if endpoint else stop - start, dtype=computation_dtype)
out = reshape(broadcast_start, bounds_shape)
else: # num == 0 degenerate case, match numpy behavior
- empty_shape = list(lax.broadcast_shapes(shape(start), shape(stop)))
+ empty_shape = list(lax.broadcast_shapes(np.shape(start), np.shape(stop)))
empty_shape.insert(axis, 0)
- delta = asarray(nan, dtype=computation_dtype)
+ delta = asarray(np.nan, dtype=computation_dtype)
out = reshape(array([], dtype=dtype), empty_shape)

if issubdtype(dtype, np.integer) and not issubdtype(out.dtype, np.integer):
@@ -7053,7 +7030,7 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *,
"value to `total_repeat_length`.")

# Fast path for when repeats is a scalar.
- if np.ndim(repeats) == 0 and ndim(arr) != 0:
+ if np.ndim(repeats) == 0 and np.ndim(arr) != 0:
input_shape = arr.shape
axis = _canonicalize_axis(axis, len(input_shape))
aux_axis = axis + 1
@@ -7076,7 +7053,7 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *,

# Special case when a is a scalar.
if arr.ndim == 0:
- if shape(repeats) == (1,):
+ if np.shape(repeats) == (1,):
return full([total_repeat_length], arr)
else:
raise ValueError('`repeat` with a scalar parameter `a` is only '
@@ -7279,7 +7256,7 @@ def tril(m: ArrayLike, k: int = 0) -> Array:
[7, 8]]], dtype=int32)
"""
util.check_arraylike("tril", m)
- m_shape = shape(m)
+ m_shape = np.shape(m)
if len(m_shape) < 2:
raise ValueError("Argument to jax.numpy.tril must be at least 2D")
N, M = m_shape[-2:]
@@ -7346,7 +7323,7 @@ def triu(m: ArrayLike, k: int = 0) -> Array:
[0, 8]]], dtype=int32)
"""
util.check_arraylike("triu", m)
- m_shape = shape(m)
+ m_shape = np.shape(m)
if len(m_shape) < 2:
raise ValueError("Argument to jax.numpy.triu must be at least 2D")
N, M = m_shape[-2:]
@@ -7406,12 +7383,12 @@ def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.trace is not supported.")

- if _canonicalize_axis(axis1, ndim(a)) == _canonicalize_axis(axis2, ndim(a)):
+ if _canonicalize_axis(axis1, np.ndim(a)) == _canonicalize_axis(axis2, np.ndim(a)):
raise ValueError(f"axis1 and axis2 can not be same. axis1={axis1} and axis2={axis2}")

dtypes.check_user_dtype_supported(dtype, "trace")

- a_shape = shape(a)
+ a_shape = np.shape(a)
a = moveaxis(a, (axis1, axis2), (-2, -1))

# Mask out the diagonal and reduce.
@@ -7650,7 +7627,7 @@ def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]:
>>> jnp.triu_indices_from(arr, k=-1)
(Array([0, 0, 0, 1, 1, 1, 2, 2], dtype=int32), Array([0, 1, 2, 0, 1, 2, 1, 2], dtype=int32))
"""
- arr_shape = shape(arr)
+ arr_shape = np.shape(arr)
if len(arr_shape) != 2:
raise ValueError("Only 2-D inputs are accepted")
return triu_indices(arr_shape[0], k=k, m=arr_shape[1])
@@ -7708,7 +7685,7 @@ def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]:
>>> jnp.tril_indices_from(arr, k=-1)
(Array([1, 2, 2], dtype=int32), Array([0, 0, 1], dtype=int32))
"""
- arr_shape = shape(arr)
+ arr_shape = np.shape(arr)
if len(arr_shape) != 2:
raise ValueError("Only 2-D inputs are accepted")
return tril_indices(arr_shape[0], k=k, m=arr_shape[1])
@@ -7863,12 +7840,12 @@ def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]:
Array([0, 1], dtype=int32))
"""
util.check_arraylike("diag_indices_from", arr)
- nd = ndim(arr)
- if not ndim(arr) >= 2:
+ nd = np.ndim(arr)
+ if not np.ndim(arr) >= 2:
raise ValueError("input array must be at least 2-d")

- s = shape(arr)
- if len(set(shape(arr))) != 1:
+ s = np.shape(arr)
+ if len(set(np.shape(arr))) != 1:
raise ValueError("All dimensions of input must be of equal length")

return diag_indices(s[0], ndim=nd)
@@ -7913,12 +7890,12 @@ def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0,
"""
util.check_arraylike("diagonal", a)

- if ndim(a) < 2:
+ if np.ndim(a) < 2:
raise ValueError("diagonal requires an array of at least two dimensions.")
offset = core.concrete_or_error(operator.index, offset, "'offset' argument of jnp.diagonal()")

def _default_diag(a):
- a_shape = shape(a)
+ a_shape = np.shape(a)

a = moveaxis(a, (axis1, axis2), (-2, -1))

@@ -7932,10 +7909,10 @@ def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0,

# The mosaic lowering rule for diag is only defined for square arrays.
# TODO(mvoz): Add support for offsets.
- if shape(a)[0] != shape(a)[1] or ndim(a) != 2 or offset != 0 or _dtype(a) == bool:
+ if np.shape(a)[0] != np.shape(a)[1] or np.ndim(a) != 2 or offset != 0 or _dtype(a) == bool:
return _default_diag(a)
else:
- a_shape_eye = eye(shape(a)[0], dtype=_dtype(a))
+ a_shape_eye = eye(np.shape(a)[0], dtype=_dtype(a))

def _mosaic_diag(a):
def _sum(x, axis):
@@ -8002,7 +7979,7 @@ def diag(v: ArrayLike, k: int = 0) -> Array:
@partial(jit, static_argnames=('k',))
def _diag(v, k):
util.check_arraylike("diag", v)
- v_shape = shape(v)
+ v_shape = np.shape(v)
if len(v_shape) == 1:
zero = lambda x: lax.full_like(x, shape=(), fill_value=0)
n = v_shape[0] + abs(k)
@@ -8472,7 +8449,7 @@ def apply_along_axis(
Array([ 65, 133, 243], dtype=int32)
"""
util.check_arraylike("apply_along_axis", arr)
- num_dims = ndim(arr)
+ num_dims = np.ndim(arr)
axis = _canonicalize_axis(axis, num_dims)
func = lambda arr: func1d(arr, *args, **kwargs)
for i in range(1, num_dims - axis):
@@ -8675,13 +8652,13 @@ def kron(a: ArrayLike, b: ArrayLike) -> Array:
"""
util.check_arraylike("kron", a, b)
a, b = util.promote_dtypes(a, b)
- if ndim(a) < ndim(b):
- a = expand_dims(a, range(ndim(b) - ndim(a)))
- elif ndim(b) < ndim(a):
- b = expand_dims(b, range(ndim(a) - ndim(b)))
- a_reshaped = expand_dims(a, range(1, 2 * ndim(a), 2))
- b_reshaped = expand_dims(b, range(0, 2 * ndim(b), 2))
- out_shape = tuple(np.multiply(shape(a), shape(b)))
+ if np.ndim(a) < np.ndim(b):
+ a = expand_dims(a, range(np.ndim(b) - np.ndim(a)))
+ elif np.ndim(b) < np.ndim(a):
+ b = expand_dims(b, range(np.ndim(a) - np.ndim(b)))
+ a_reshaped = expand_dims(a, range(1, 2 * np.ndim(a), 2))
+ b_reshaped = expand_dims(b, range(0, 2 * np.ndim(b), 2))
+ out_shape = tuple(np.multiply(np.shape(a), np.shape(b)))
return reshape(lax.mul(a_reshaped, b_reshaped), out_shape)


@@ -8809,9 +8786,9 @@ def argwhere(
Array([], shape=(0, 0), dtype=int32)
"""
result = transpose(vstack(nonzero(atleast_1d(a), size=size, fill_value=fill_value)))
- if ndim(a) == 0:
+ if np.ndim(a) == 0:
return result[:0].reshape(result.shape[0], 0)
- return result.reshape(result.shape[0], ndim(a))
+ return result.reshape(result.shape[0], np.ndim(a))


@export
@@ -8859,7 +8836,7 @@ def argmax(a: ArrayLike, axis: int | None = None, out: None = None,
@partial(jit, static_argnames=('axis', 'keepdims'), inline=True)
def _argmax(a: Array, axis: int | None = None, keepdims: bool = False) -> Array:
if axis is None:
- dims = list(range(ndim(a)))
+ dims = list(range(np.ndim(a)))
a = ravel(a)
axis = 0
else:
@@ -8915,7 +8892,7 @@ def argmin(a: ArrayLike, axis: int | None = None, out: None = None,
@partial(jit, static_argnames=('axis', 'keepdims'), inline=True)
def _argmin(a: Array, axis: int | None = None, keepdims: bool = False) -> Array:
if axis is None:
- dims = list(range(ndim(a)))
+ dims = list(range(np.ndim(a)))
a = ravel(a)
axis = 0
else:
@@ -8989,7 +8966,7 @@ def _nanargmax(a, axis: int | None = None, keepdims: bool = False):
if not issubdtype(_dtype(a), np.inexact):
return argmax(a, axis=axis, keepdims=keepdims)
nan_mask = ufuncs.isnan(a)
- a = where(nan_mask, -inf, a)
+ a = where(nan_mask, -np.inf, a)
res = argmax(a, axis=axis, keepdims=keepdims)
return where(reductions.all(nan_mask, axis=axis, keepdims=keepdims), -1, res)

@@ -9050,7 +9027,7 @@ def _nanargmin(a, axis: int | None = None, keepdims : bool = False):
if not issubdtype(_dtype(a), np.inexact):
return argmin(a, axis=axis, keepdims=keepdims)
nan_mask = ufuncs.isnan(a)
- a = where(nan_mask, inf, a)
+ a = where(nan_mask, np.inf, a)
res = argmin(a, axis=axis, keepdims=keepdims)
return where(reductions.all(nan_mask, axis=axis, keepdims=keepdims), -1, res)

@@ -9191,7 +9168,7 @@ def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array:
"""
util.check_arraylike("rollaxis", a)
start = core.concrete_or_error(operator.index, start, "'start' argument of jnp.rollaxis()")
- a_ndim = ndim(a)
+ a_ndim = np.ndim(a)
axis = _canonicalize_axis(axis, a_ndim)
if not (-a_ndim <= start <= a_ndim):
raise ValueError(f"{start=} must satisfy {-a_ndim}<=start<={a_ndim}")
@@ -9764,9 +9741,9 @@ def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True,
w: Array | None = None
if fweights is not None:
fweights = util.ensure_arraylike("cov", fweights)
- if ndim(fweights) > 1:
+ if np.ndim(fweights) > 1:
raise RuntimeError("cannot handle multidimensional fweights")
- if shape(fweights)[0] != X.shape[1]:
+ if np.shape(fweights)[0] != X.shape[1]:
raise RuntimeError("incompatible numbers of samples and fweights")
if not issubdtype(_dtype(fweights), np.integer):
raise TypeError("fweights must be integer.")
@@ -9774,9 +9751,9 @@ def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True,
w = abs(fweights)
if aweights is not None:
aweights = util.ensure_arraylike("cov", aweights)
- if ndim(aweights) > 1:
+ if np.ndim(aweights) > 1:
raise RuntimeError("cannot handle multidimensional aweights")
- if shape(aweights)[0] != X.shape[1]:
+ if np.shape(aweights)[0] != X.shape[1]:
raise RuntimeError("incompatible numbers of samples and aweights")
# Ensure positive aweights: note that numpy raises an error for negative aweights.
aweights = abs(aweights)
@@ -9877,7 +9854,7 @@ def corrcoef(x: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True) -> A
"""
util.check_arraylike("corrcoef", x)
c = cov(x, y, rowvar)
- if len(shape(c)) == 0:
+ if len(np.shape(c)) == 0:
# scalar - this should yield nan for values (nan/nan, inf/inf, 0/0), 1 otherwise
return ufuncs.divide(c, c)
d = diag(c)
@@ -10002,7 +9979,7 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left',
raise ValueError(
f"{method!r} is an invalid value for keyword 'method'. "
"Expected one of ['sort', 'scan', 'scan_unrolled', 'compare_all'].")
- if ndim(a) != 1:
+ if np.ndim(a) != 1:
raise ValueError("a should be 1-dimensional")
a, v = util.promote_dtypes(a, v)
if sorter is not None:

@@ -478,7 +478,7 @@ def _slogdet_lu(a: Array) -> tuple[Array, Array]:
jnp.array(0, dtype=dtype),
sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype))
logdet = jnp.where(
- is_zero, jnp.array(-jnp.inf, dtype=dtype),
+ is_zero, jnp.array(-np.inf, dtype=dtype),
reductions.sum(ufuncs.log(ufuncs.abs(diag)).astype(dtype), axis=-1))
return sign, ufuncs.real(logdet)

@@ -539,7 +539,7 @@ def slogdet(a: ArrayLike, *, method: str | None = None) -> SlogdetResult:
"""
a = ensure_arraylike("jnp.linalg.slogdet", a)
a, = promote_dtypes_inexact(a)
- a_shape = jnp.shape(a)
+ a_shape = np.shape(a)
if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
raise ValueError(f"Argument to slogdet() must have shape [..., n, n], got {a_shape}")
if method is None or method == "lu":
@@ -610,8 +610,8 @@ def _cofactor_solve(a: ArrayLike, b: ArrayLike) -> tuple[Array, Array]:
a, b = ensure_arraylike("jnp.linalg._cofactor_solve", a, b)
a, = promote_dtypes_inexact(a)
b, = promote_dtypes_inexact(b)
- a_shape = jnp.shape(a)
- b_shape = jnp.shape(b)
+ a_shape = np.shape(a)
+ b_shape = np.shape(b)
a_ndims = len(a_shape)
if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
and b_shape[-2:] == a_shape[-2:]):
@@ -710,7 +710,7 @@ def det(a: ArrayLike) -> Array:
"""
a = ensure_arraylike("jnp.linalg.det", a)
a, = promote_dtypes_inexact(a)
- a_shape = jnp.shape(a)
+ a_shape = np.shape(a)
if len(a_shape) >= 2 and a_shape[-1] == 2 and a_shape[-2] == 2:
return _det_2x2(a)
elif len(a_shape) >= 2 and a_shape[-1] == 3 and a_shape[-2] == 3:
@@ -976,10 +976,10 @@ def _pinv(a: ArrayLike, rtol: ArrayLike | None = None, hermitian: bool = False)
u, s, vh = svd(arr, full_matrices=False, hermitian=hermitian)
# Singular values less than or equal to ``rtol * largest_singular_value``
# are set to zero.
- rtol = lax.expand_dims(rtol[..., jnp.newaxis], range(s.ndim - rtol.ndim - 1))
+ rtol = lax.expand_dims(rtol[..., np.newaxis], range(s.ndim - rtol.ndim - 1))
cutoff = rtol * s[..., 0:1]
- s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype)
- res = tensor_contractions.matmul(vh.mT, ufuncs.divide(u.mT, s[..., jnp.newaxis]),
+ s = jnp.where(s > cutoff, s, np.inf).astype(u.dtype)
+ res = tensor_contractions.matmul(vh.mT, ufuncs.divide(u.mT, s[..., np.newaxis]),
precision=lax.Precision.HIGHEST)
return lax.convert_element_type(res, arr.dtype)

@@ -1148,7 +1148,7 @@ def norm(x: ArrayLike, ord: int | str | None = None,
"""
x = ensure_arraylike("jnp.linalg.norm", x)
x, = promote_dtypes_inexact(x)
- x_shape = jnp.shape(x)
+ x_shape = np.shape(x)
ndim = len(x_shape)

if axis is None:
@@ -1181,12 +1181,12 @@ def norm(x: ArrayLike, ord: int | str | None = None,
col_axis -= 1
return reductions.amin(reductions.sum(ufuncs.abs(x), axis=row_axis, keepdims=keepdims),
axis=col_axis, keepdims=keepdims)
- elif ord == jnp.inf:
+ elif ord == np.inf:
if not keepdims and row_axis > col_axis:
row_axis -= 1
return reductions.amax(reductions.sum(ufuncs.abs(x), axis=col_axis, keepdims=keepdims),
axis=row_axis, keepdims=keepdims)
- elif ord == -jnp.inf:
+ elif ord == -np.inf:
if not keepdims and row_axis > col_axis:
row_axis -= 1
return reductions.amin(reductions.sum(ufuncs.abs(x), axis=col_axis, keepdims=keepdims),
@@ -1392,7 +1392,7 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None, *,
mask = s >= jnp.array(rcond, dtype=s.dtype) * s[0]
rank = mask.sum()
safe_s = jnp.where(mask, s, 1).astype(a.dtype)
- s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis]
+ s_inv = jnp.where(mask, 1 / safe_s, 0)[:, np.newaxis]
uTb = tensor_contractions.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST)
x = tensor_contractions.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST)
# Numpy returns empty residuals in some cases. To allow compilation, we
@@ -1651,9 +1651,9 @@ def vector_norm(x: ArrayLike, /, *, axis: int | tuple[int, ...] | None = None, k
if ord is None or ord == 2:
return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis,
keepdims=keepdims))
- elif ord == jnp.inf:
+ elif ord == np.inf:
return reductions.amax(ufuncs.abs(x), axis=axis, keepdims=keepdims)
- elif ord == -jnp.inf:
+ elif ord == -np.inf:
return reductions.amin(ufuncs.abs(x), axis=axis, keepdims=keepdims)
elif ord == 0:
return reductions.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype,
@@ -2177,7 +2177,7 @@ def cond(x: ArrayLike, p=None):
raise ValueError(f"jnp.linalg.cond: for {p=}, array must be square; got {arr.shape=}")
r = norm(x, ord=p, axis=(-2, -1)) * norm(inv(x), ord=p, axis=(-2, -1))
# Convert NaNs to infs where original array has no NaNs.
- return jnp.where(ufuncs.isnan(r) & ~ufuncs.isnan(x).any(axis=(-2, -1)), jnp.inf, r)
+ return jnp.where(ufuncs.isnan(r) & ~ufuncs.isnan(x).any(axis=(-2, -1)), np.inf, r)


@export

@@ -19,6 +19,8 @@ import re
from typing import Any
import warnings

+ import numpy as np
+
from jax._src import api
from jax._src import config
from jax import lax
@@ -140,7 +142,7 @@ def _check_output_dims(
"""Check that output core dimensions match the signature."""
def wrapped(*args):
out = func(*args)
- out_shapes = map(jnp.shape, out if isinstance(out, tuple) else [out])
+ out_shapes = map(np.shape, out if isinstance(out, tuple) else [out])

if expected_output_core_dims is None:
output_core_dims = [()] * len(out_shapes)

@@ -98,7 +98,7 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
FutureWarning)

idx = indexing.merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
- indexer = indexing.index_to_gather(jnp.shape(x), idx,
+ indexer = indexing.index_to_gather(np.shape(x), idx,
normalize_indices=normalize_indices)

# Avoid calling scatter if the slice shape is empty, both as a fast path and
@@ -110,7 +110,7 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,

# Broadcast `y` to the slice output shape.
y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
- # Collapse any `None`/`jnp.newaxis` dimensions.
+ # Collapse any `None`/`np.newaxis` dimensions.
y = jnp.squeeze(y, axis=indexer.newaxis_dims)
if indexer.reversed_y_dims:
y = lax.rev(y, indexer.reversed_y_dims)

@@ -39,9 +39,7 @@ from jax._src.numpy.lax_numpy import (
array as array,
array_equal as array_equal,
array_equiv as array_equiv,
- array_repr as array_repr,
array_split as array_split,
- array_str as array_str,
astype as astype,
asarray as asarray,
atleast_1d as atleast_1d,
@@ -75,10 +73,7 @@ from jax._src.numpy.lax_numpy import (
digitize as digitize,
dsplit as dsplit,
dstack as dstack,
- dtype as dtype,
- e as e,
ediff1d as ediff1d,
- euler_gamma as euler_gamma,
expand_dims as expand_dims,
extract as extract,
eye as eye,
@@ -111,7 +106,6 @@ from jax._src.numpy.lax_numpy import (
identity as identity,
iinfo as iinfo,
indices as indices,
- inf as inf,
insert as insert,
interp as interp,
isclose as isclose,
@@ -121,7 +115,6 @@ from jax._src.numpy.lax_numpy import (
isrealobj as isrealobj,
isscalar as isscalar,
issubdtype as issubdtype,
- iterable as iterable,
ix_ as ix_,
kron as kron,
lcm as lcm,
@@ -132,17 +125,13 @@ from jax._src.numpy.lax_numpy import (
matrix_transpose as matrix_transpose,
meshgrid as meshgrid,
moveaxis as moveaxis,
- nan as nan,
nan_to_num as nan_to_num,
nanargmax as nanargmax,
nanargmin as nanargmin,
- ndim as ndim,
- newaxis as newaxis,
nonzero as nonzero,
packbits as packbits,
pad as pad,
permute_dims as permute_dims,
- pi as pi,
piecewise as piecewise,
printoptions as printoptions,
promote_types as promote_types,
@@ -156,13 +145,9 @@ from jax._src.numpy.lax_numpy import (
rollaxis as rollaxis,
rot90 as rot90,
round as round,
- save as save,
- savez as savez,
searchsorted as searchsorted,
select as select,
set_printoptions as set_printoptions,
- shape as shape,
- size as size,
split as split,
squeeze as squeeze,
stack as stack,
@@ -277,17 +262,32 @@ from jax._src.numpy.window_functions import (
kaiser as kaiser,
)

- # NumPy generic scalar types:
+ # Some APIs come directly from NumPy:
from numpy import (
+ array_repr as array_repr,
+ array_str as array_str,
character as character,
complexfloating as complexfloating,
+ dtype as dtype,
+ e as e,
+ euler_gamma as euler_gamma,
flexible as flexible,
floating as floating,
generic as generic,
inexact as inexact,
+ inf as inf,
integer as integer,
+ iterable as iterable,
+ nan as nan,
+ ndim as ndim,
+ newaxis as newaxis,
number as number,
object_ as object_,
+ pi as pi,
+ save as save,
+ savez as savez,
+ shape as shape,
+ size as size,
signedinteger as signedinteger,
unsignedinteger as unsignedinteger,
)
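A short usage sketch (my addition; it assumes the import block above is the public jax.numpy namespace, which the diff does not name explicitly): after this change the re-exported helpers and constants in jax.numpy are the NumPy objects themselves, so the two namespaces agree by identity, not merely by value, while array-producing APIs are unchanged.

import numpy as np
import jax.numpy as jnp

assert jnp.newaxis is None and np.newaxis is None
assert jnp.ndim is np.ndim and jnp.shape is np.shape and jnp.size is np.size
assert jnp.pi == np.pi and jnp.nan is np.nan and jnp.inf == np.inf

print(jnp.shape(jnp.arange(6).reshape(2, 3)))  # (2, 3)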