mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add ensure_arraylike utility for lax.numpy implementations
This commit is contained in:
parent
994c3f59e2
commit
4c926c8d4c
@ -710,9 +710,9 @@ def trunc(x: ArrayLike) -> Array:
|
||||
[ 1., -0., 1.],
|
||||
[-8., 5., 3.]], dtype=float32)
|
||||
"""
|
||||
util.check_arraylike('trunc', x)
|
||||
x = util.ensure_arraylike('trunc', x)
|
||||
if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')):
|
||||
return lax_internal.asarray(x)
|
||||
return x
|
||||
return where(lax.lt(x, _lax_const(x, 0)), ufuncs.ceil(x), ufuncs.floor(x))
|
||||
|
||||
|
||||
@ -827,8 +827,8 @@ def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *,
|
||||
>>> jnp.convolve(x1, y1)
|
||||
Array([ 3. +1.j, 11. -7.j, 15.+10.j, 7. -8.j, 31. +8.j], dtype=complex64)
|
||||
"""
|
||||
util.check_arraylike("convolve", a, v)
|
||||
return _conv(asarray(a), asarray(v), mode=mode, op='convolve',
|
||||
a, v = util.ensure_arraylike("convolve", a, v)
|
||||
return _conv(a, v, mode=mode, op='convolve',
|
||||
precision=precision, preferred_element_type=preferred_element_type)
|
||||
|
||||
|
||||
@ -913,8 +913,8 @@ def correlate(a: ArrayLike, v: ArrayLike, mode: str = 'valid', *,
|
||||
>>> jnp.correlate(x2, y2, mode='full')
|
||||
Array([ 3. +1.j, 3.+17.j, 18.+11.j, 27. +4.j, 8.-12.j], dtype=complex64)
|
||||
"""
|
||||
util.check_arraylike("correlate", a, v)
|
||||
return _conv(asarray(a), asarray(v), mode=mode, op='correlate',
|
||||
a, v = util.ensure_arraylike("correlate", a, v)
|
||||
return _conv(a, v, mode=mode, op='correlate',
|
||||
precision=precision, preferred_element_type=preferred_element_type)
|
||||
|
||||
|
||||
@ -1556,8 +1556,8 @@ def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array:
|
||||
[[8, 7],
|
||||
[6, 5]]], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("flip", m)
|
||||
return _flip(asarray(m), reductions._ensure_optional_axes(axis))
|
||||
arr = util.ensure_arraylike("flip", m)
|
||||
return _flip(arr, reductions._ensure_optional_axes(axis))
|
||||
|
||||
@partial(jit, static_argnames=('axis',))
|
||||
def _flip(m: Array, axis: int | tuple[int, ...] | None = None) -> Array:
|
||||
@ -1590,8 +1590,8 @@ def fliplr(m: ArrayLike) -> Array:
|
||||
Array([[2, 1],
|
||||
[4, 3]], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("fliplr", m)
|
||||
return _flip(asarray(m), 1)
|
||||
arr = util.ensure_arraylike("fliplr", m)
|
||||
return _flip(arr, 1)
|
||||
|
||||
|
||||
@export
|
||||
@ -1617,8 +1617,8 @@ def flipud(m: ArrayLike) -> Array:
|
||||
Array([[3, 4],
|
||||
[1, 2]], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("flipud", m)
|
||||
return _flip(asarray(m), 0)
|
||||
arr = util.ensure_arraylike("flipud", m)
|
||||
return _flip(arr, 0)
|
||||
|
||||
|
||||
@export
|
||||
@ -1786,8 +1786,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1,
|
||||
Array([[ 4, -3, 7, -6],
|
||||
[ 5, -1, -3, -3]], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("diff", a)
|
||||
arr = asarray(a)
|
||||
arr = util.ensure_arraylike("diff", a)
|
||||
n = core.concrete_or_error(operator.index, n, "'n' argument of jnp.diff")
|
||||
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.diff")
|
||||
if n == 0:
|
||||
@ -1802,22 +1801,22 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1,
|
||||
|
||||
combined: list[Array] = []
|
||||
if prepend is not None:
|
||||
util.check_arraylike("diff", prepend)
|
||||
prepend = util.ensure_arraylike("diff", prepend)
|
||||
if not ndim(prepend):
|
||||
shape = list(arr.shape)
|
||||
shape[axis] = 1
|
||||
prepend = broadcast_to(prepend, tuple(shape))
|
||||
combined.append(asarray(prepend))
|
||||
combined.append(prepend)
|
||||
|
||||
combined.append(arr)
|
||||
|
||||
if append is not None:
|
||||
util.check_arraylike("diff", append)
|
||||
append = util.ensure_arraylike("diff", append)
|
||||
if not ndim(append):
|
||||
shape = list(arr.shape)
|
||||
shape[axis] = 1
|
||||
append = broadcast_to(append, tuple(shape))
|
||||
combined.append(asarray(append))
|
||||
combined.append(append)
|
||||
|
||||
if len(combined) > 1:
|
||||
arr = concatenate(combined, axis)
|
||||
@ -1888,15 +1887,14 @@ def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = None,
|
||||
>>> jnp.ediff1d(a2)
|
||||
Array([ -3, 5, 3, -4, 2, -11, 15], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("ediff1d", ary)
|
||||
arr = ravel(ary)
|
||||
arr = util.ensure_arraylike("ediff1d", ary).ravel()
|
||||
result = lax.sub(arr[1:], arr[:-1])
|
||||
if to_begin is not None:
|
||||
util.check_arraylike("ediff1d", to_begin)
|
||||
result = concatenate((ravel(asarray(to_begin, dtype=arr.dtype)), result))
|
||||
to_begin = util.ensure_arraylike("ediff1d", to_begin)
|
||||
result = concatenate((ravel(to_begin.astype(arr.dtype)), result))
|
||||
if to_end is not None:
|
||||
util.check_arraylike("ediff1d", to_end)
|
||||
result = concatenate((result, ravel(asarray(to_end, dtype=arr.dtype))))
|
||||
to_end = util.ensure_arraylike("ediff1d", to_end)
|
||||
result = concatenate((result, ravel(to_end.astype(arr.dtype))))
|
||||
return result
|
||||
|
||||
|
||||
@ -2350,8 +2348,7 @@ def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]:
|
||||
>>> jnp.ravel_multi_index(indices_2D, shape)
|
||||
Array([1, 3, 5], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("unravel_index", indices)
|
||||
indices_arr = asarray(indices)
|
||||
indices_arr = util.ensure_arraylike("unravel_index", indices)
|
||||
# Note: we do not convert shape to an array, because it may be passed as a
|
||||
# tuple of weakly-typed values, and asarray() would strip these weak types.
|
||||
try:
|
||||
@ -2480,8 +2477,8 @@ def squeeze(a: ArrayLike, axis: int | Sequence[int] | None = None) -> Array:
|
||||
>>> x.squeeze()
|
||||
Array([0, 1, 2], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("squeeze", a)
|
||||
return _squeeze(asarray(a), _ensure_index_tuple(axis) if axis is not None else None)
|
||||
arr = util.ensure_arraylike("squeeze", a)
|
||||
return _squeeze(arr, _ensure_index_tuple(axis) if axis is not None else None)
|
||||
|
||||
@partial(jit, static_argnames=('axis',), inline=True)
|
||||
def _squeeze(a: Array, axis: tuple[int, ...]) -> Array:
|
||||
@ -2662,8 +2659,8 @@ def moveaxis(a: ArrayLike, source: int | Sequence[int],
|
||||
>>> a.transpose(2, 3, 1, 0).shape
|
||||
(4, 5, 3, 2)
|
||||
"""
|
||||
util.check_arraylike("moveaxis", a)
|
||||
return _moveaxis(asarray(a), _ensure_index_tuple(source),
|
||||
arr = util.ensure_arraylike("moveaxis", a)
|
||||
return _moveaxis(arr, _ensure_index_tuple(source),
|
||||
_ensure_index_tuple(destination))
|
||||
|
||||
@partial(jit, static_argnames=('source', 'destination'), inline=True)
|
||||
@ -3266,8 +3263,7 @@ def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array:
|
||||
def _split(op: str, ary: ArrayLike,
|
||||
indices_or_sections: int | Sequence[int] | ArrayLike,
|
||||
axis: int = 0) -> list[Array]:
|
||||
util.check_arraylike(op, ary)
|
||||
ary = asarray(ary)
|
||||
ary = util.ensure_arraylike(op, ary)
|
||||
axis = core.concrete_or_error(operator.index, axis, f"in jax.numpy.{op} argument `axis`")
|
||||
size = ary.shape[axis]
|
||||
if (isinstance(indices_or_sections, (tuple, list)) or
|
||||
@ -3430,8 +3426,7 @@ def hsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike)
|
||||
- :func:`jax.numpy.array_split`: like ``split``, but allows ``indices_or_sections``
|
||||
to be an integer that does not evenly divide the size of the array.
|
||||
"""
|
||||
util.check_arraylike("hsplit", ary)
|
||||
a = asarray(ary)
|
||||
a = util.ensure_arraylike("hsplit", ary)
|
||||
return _split("hsplit", a, indices_or_sections, axis=0 if a.ndim == 1 else 1)
|
||||
|
||||
|
||||
@ -3616,7 +3611,7 @@ def round(a: ArrayLike, decimals: int = 0, out: None = None) -> Array:
|
||||
>>> jnp.round(x1)
|
||||
Array([10., 22., 12., 32.], dtype=float32)
|
||||
"""
|
||||
util.check_arraylike("round", a)
|
||||
a = util.ensure_arraylike("round", a)
|
||||
decimals = core.concrete_or_error(operator.index, decimals, "'decimals' argument of jnp.round")
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.round is not supported.")
|
||||
@ -3625,7 +3620,7 @@ def round(a: ArrayLike, decimals: int = 0, out: None = None) -> Array:
|
||||
if decimals < 0:
|
||||
raise NotImplementedError(
|
||||
"integer np.round not implemented for decimals < 0")
|
||||
return asarray(a) # no-op on integer types
|
||||
return a # no-op on integer types
|
||||
|
||||
def _round_float(x: ArrayLike) -> Array:
|
||||
if decimals == 0:
|
||||
@ -3742,10 +3737,10 @@ def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0,
|
||||
Array([ 0., 0., 1., inf, 2., -inf], dtype=float32)
|
||||
"""
|
||||
del copy
|
||||
util.check_arraylike("nan_to_num", x)
|
||||
x = util.ensure_arraylike("nan_to_num", x)
|
||||
dtype = _dtype(x)
|
||||
if not issubdtype(dtype, inexact):
|
||||
return asarray(x)
|
||||
return x
|
||||
if issubdtype(dtype, complexfloating):
|
||||
return lax.complex(
|
||||
nan_to_num(lax.real(x), nan=nan, posinf=posinf, neginf=neginf),
|
||||
@ -3890,8 +3885,7 @@ def nonzero(a: ArrayLike, *, size: int | None = None,
|
||||
>>> nonzero_jit(x, size=5, fill_value=len(x))
|
||||
(Array([1, 3, 5, 6, 6], dtype=int32),)
|
||||
"""
|
||||
util.check_arraylike("nonzero", a)
|
||||
arr = asarray(a)
|
||||
arr = util.ensure_arraylike("nonzero", a)
|
||||
del a
|
||||
if ndim(arr) == 0:
|
||||
raise ValueError("Calling nonzero on 0d arrays is not allowed. "
|
||||
@ -4020,8 +4014,7 @@ def unwrap(p: ArrayLike, discont: ArrayLike | None = None,
|
||||
a larger discontinuity it adds factors of the period to the data. For periodic signals
|
||||
that satisfy this assumption, :func:`unwrap` can recover the original phased signal.
|
||||
"""
|
||||
util.check_arraylike("unwrap", p)
|
||||
p = asarray(p)
|
||||
p = util.ensure_arraylike("unwrap", p)
|
||||
if issubdtype(p.dtype, np.complexfloating):
|
||||
raise ValueError("jnp.unwrap does not support complex inputs.")
|
||||
if p.shape[axis] == 0:
|
||||
@ -4648,8 +4641,7 @@ def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]:
|
||||
Array([[1, 2, 3],
|
||||
[4, 5, 6]], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("unstack", x)
|
||||
x = asarray(x)
|
||||
x = util.ensure_arraylike("unstack", x)
|
||||
if x.ndim == 0:
|
||||
raise ValueError(
|
||||
"Unstack requires arrays with rank > 0, however a scalar array was "
|
||||
@ -5712,8 +5704,7 @@ def astype(x: ArrayLike, dtype: DTypeLike | None,
|
||||
>>> y.astype(int) # truncates fractional values
|
||||
Array([0, 0, 1], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("astype", x)
|
||||
x_arr = asarray(x)
|
||||
x_arr = util.ensure_arraylike("astype", x)
|
||||
|
||||
if dtype is None:
|
||||
dtype = dtypes.canonicalize_dtype(float_)
|
||||
@ -6642,8 +6633,7 @@ def _eye(N: DimSize, M: DimSize | None = None,
|
||||
if isinstance(k, int):
|
||||
k = lax_internal._clip_int_to_valid_range(k, np.int32,
|
||||
"`argument `k` of jax.numpy.eye")
|
||||
util.check_arraylike("eye", k)
|
||||
offset = asarray(k)
|
||||
offset = util.ensure_arraylike("eye", k)
|
||||
if not (offset.shape == () and dtypes.issubdtype(offset.dtype, np.integer)):
|
||||
raise ValueError(f"k must be a scalar integer; got {k}")
|
||||
N_int = core.canonicalize_dim(N, "argument of 'N' jnp.eye()")
|
||||
@ -6935,14 +6925,14 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
|
||||
dtypes.check_user_dtype_supported(dtype, "linspace")
|
||||
if num < 0:
|
||||
raise ValueError(f"Number of samples, {num}, must be non-negative.")
|
||||
util.check_arraylike("linspace", start, stop)
|
||||
start, stop = util.ensure_arraylike("linspace", start, stop)
|
||||
|
||||
if dtype is None:
|
||||
dtype = dtypes.to_inexact_dtype(result_type(start, stop))
|
||||
dtype = _jnp_dtype(dtype)
|
||||
computation_dtype = dtypes.to_inexact_dtype(dtype)
|
||||
start = asarray(start, dtype=computation_dtype)
|
||||
stop = asarray(stop, dtype=computation_dtype)
|
||||
start = start.astype(computation_dtype)
|
||||
stop = stop.astype(computation_dtype)
|
||||
|
||||
bounds_shape = list(lax.broadcast_shapes(shape(start), shape(stop)))
|
||||
broadcast_start = broadcast_to(start, bounds_shape)
|
||||
@ -7061,9 +7051,9 @@ def _logspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
|
||||
dtype = dtypes.to_inexact_dtype(result_type(start, stop))
|
||||
dtype = _jnp_dtype(dtype)
|
||||
computation_dtype = dtypes.to_inexact_dtype(dtype)
|
||||
util.check_arraylike("logspace", start, stop)
|
||||
start = asarray(start, dtype=computation_dtype)
|
||||
stop = asarray(stop, dtype=computation_dtype)
|
||||
start, stop = util.ensure_arraylike("logspace", start, stop)
|
||||
start = start.astype(computation_dtype)
|
||||
stop = stop.astype(computation_dtype)
|
||||
lin = linspace(start, stop, num,
|
||||
endpoint=endpoint, retstep=False, dtype=None, axis=axis)
|
||||
return lax.convert_element_type(ufuncs.power(base, lin), dtype)
|
||||
@ -7131,9 +7121,9 @@ def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool
|
||||
dtype = dtypes.to_inexact_dtype(result_type(start, stop))
|
||||
dtype = _jnp_dtype(dtype)
|
||||
computation_dtype = dtypes.to_inexact_dtype(dtype)
|
||||
util.check_arraylike("geomspace", start, stop)
|
||||
start = asarray(start, dtype=computation_dtype)
|
||||
stop = asarray(stop, dtype=computation_dtype)
|
||||
start, stop = util.ensure_arraylike("geomspace", start, stop)
|
||||
start = start.astype(computation_dtype)
|
||||
stop = stop.astype(computation_dtype)
|
||||
|
||||
sign = ufuncs.sign(start)
|
||||
res = sign * logspace(ufuncs.log10(start / sign), ufuncs.log10(stop / sign),
|
||||
@ -7207,8 +7197,7 @@ def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False,
|
||||
[[10 20 30]
|
||||
[10 20 30]]
|
||||
"""
|
||||
util.check_arraylike("meshgrid", *xi)
|
||||
args = [asarray(x) for x in xi]
|
||||
args = list(util.ensure_arraylike_tuple("meshgrid", tuple(xi)))
|
||||
if not copy:
|
||||
raise ValueError("jax.numpy.meshgrid only supports copy=True")
|
||||
if indexing not in ["xy", "ij"]:
|
||||
@ -7310,11 +7299,10 @@ def ix_(*args: ArrayLike) -> tuple[Array, ...]:
|
||||
Array([[ 20, 40],
|
||||
[100, 120]], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("ix", *args)
|
||||
args = util.ensure_arraylike_tuple("ix", args)
|
||||
n = len(args)
|
||||
output = []
|
||||
for i, a in enumerate(args):
|
||||
a = asarray(a)
|
||||
if len(a.shape) != 1:
|
||||
msg = "Arguments to jax.numpy.ix_ must be 1-dimensional, got shape {}"
|
||||
raise ValueError(msg.format(a.shape))
|
||||
@ -7457,14 +7445,12 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *,
|
||||
Array([[1, 1, 2, 2, 2, 2, 2],
|
||||
[3, 3, 4, 4, 4, 4, 4]], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("repeat", a)
|
||||
arr = util.ensure_arraylike("repeat", a)
|
||||
core.is_dim(repeats) or util.check_arraylike("repeat", repeats)
|
||||
|
||||
if axis is None:
|
||||
a = ravel(a)
|
||||
arr = arr.ravel()
|
||||
axis = 0
|
||||
else:
|
||||
a = asarray(a)
|
||||
|
||||
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.repeat()")
|
||||
assert isinstance(axis, int) # to appease mypy
|
||||
@ -7482,44 +7468,44 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *,
|
||||
"value to `total_repeat_length`.")
|
||||
|
||||
# Fast path for when repeats is a scalar.
|
||||
if np.ndim(repeats) == 0 and ndim(a) != 0:
|
||||
input_shape = shape(a)
|
||||
if np.ndim(repeats) == 0 and ndim(arr) != 0:
|
||||
input_shape = arr.shape
|
||||
axis = _canonicalize_axis(axis, len(input_shape))
|
||||
aux_axis = axis + 1
|
||||
aux_shape: list[DimSize] = list(input_shape)
|
||||
aux_shape.insert(aux_axis, operator.index(repeats) if core.is_constant_dim(repeats) else repeats) # type: ignore
|
||||
a = lax.broadcast_in_dim(
|
||||
a, aux_shape, [i for i in range(len(aux_shape)) if i != aux_axis])
|
||||
arr = lax.broadcast_in_dim(
|
||||
arr, aux_shape, [i for i in range(len(aux_shape)) if i != aux_axis])
|
||||
result_shape: list[DimSize] = list(input_shape)
|
||||
result_shape[axis] *= repeats
|
||||
return reshape(a, result_shape)
|
||||
return arr.reshape(result_shape)
|
||||
|
||||
repeats = np.ravel(repeats)
|
||||
if ndim(a) != 0:
|
||||
repeats = np.broadcast_to(repeats, [shape(a)[axis]])
|
||||
if arr.ndim != 0:
|
||||
repeats = np.broadcast_to(repeats, [arr.shape[axis]])
|
||||
total_repeat_length = np.sum(repeats)
|
||||
else:
|
||||
repeats = ravel(repeats)
|
||||
if ndim(a) != 0:
|
||||
repeats = broadcast_to(repeats, [shape(a)[axis]])
|
||||
if arr.ndim != 0:
|
||||
repeats = broadcast_to(repeats, [arr.shape[axis]])
|
||||
|
||||
# Special case when a is a scalar.
|
||||
if ndim(a) == 0:
|
||||
if arr.ndim == 0:
|
||||
if shape(repeats) == (1,):
|
||||
return full([total_repeat_length], a)
|
||||
return full([total_repeat_length], arr)
|
||||
else:
|
||||
raise ValueError('`repeat` with a scalar parameter `a` is only '
|
||||
'implemented for scalar values of the parameter `repeats`.')
|
||||
|
||||
# Special case if total_repeat_length is zero.
|
||||
if total_repeat_length == 0:
|
||||
result_shape = list(shape(a))
|
||||
result_shape = list(arr.shape)
|
||||
result_shape[axis] = 0
|
||||
return reshape(array([], dtype=_dtype(a)), result_shape)
|
||||
return reshape(array([], dtype=arr.dtype), result_shape)
|
||||
|
||||
# If repeats is on a zero sized axis, then return the array.
|
||||
if shape(a)[axis] == 0:
|
||||
return asarray(a)
|
||||
if arr.shape[axis] == 0:
|
||||
return arr
|
||||
|
||||
# This implementation of repeat avoid having to instantiate a large.
|
||||
# intermediate tensor.
|
||||
@ -7533,7 +7519,7 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *,
|
||||
block_split_indicators = block_split_indicators.at[scatter_indices].add(1)
|
||||
# Cumsum again to get scatter indices for repeat, e.g. [0,1,1,3,3,3,3,3]
|
||||
gather_indices = reductions.cumsum(block_split_indicators) - 1
|
||||
return take(a, gather_indices, axis=axis)
|
||||
return take(arr, gather_indices, axis=axis)
|
||||
|
||||
|
||||
@export
|
||||
@ -8213,9 +8199,7 @@ def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = False, *,
|
||||
raise NotImplementedError("JAX arrays are immutable, must use inplace=False")
|
||||
if wrap:
|
||||
raise NotImplementedError("wrap=True is not implemented, must use wrap=False")
|
||||
util.check_arraylike("fill_diagonal", a, val)
|
||||
a = asarray(a)
|
||||
val = asarray(val)
|
||||
a, val = util.ensure_arraylike("fill_diagonal", a, val)
|
||||
if a.ndim < 2:
|
||||
raise ValueError("array must be at least 2-d")
|
||||
if a.ndim > 2 and not all(n == a.shape[0] for n in a.shape[1:]):
|
||||
@ -8685,11 +8669,10 @@ def delete(
|
||||
>>> jit_delete(a, indices, assume_unique_indices=True)
|
||||
Array([6, 8, 9], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("delete", arr)
|
||||
a = util.ensure_arraylike("delete", arr)
|
||||
if axis is None:
|
||||
arr = ravel(arr)
|
||||
a = a.ravel()
|
||||
axis = 0
|
||||
a = asarray(arr)
|
||||
axis = _canonicalize_axis(axis, a.ndim)
|
||||
|
||||
# Case 1: obj is a static integer.
|
||||
@ -8788,9 +8771,7 @@ def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike,
|
||||
Array([[ 1, 10, 2, 3, 11],
|
||||
[ 4, 12, 5, 6, 13]], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("insert", arr, 0 if isinstance(obj, slice) else obj, values)
|
||||
a = asarray(arr)
|
||||
values_arr = asarray(values)
|
||||
a, _, values_arr = util.ensure_arraylike("insert", arr, 0 if isinstance(obj, slice) else obj, values)
|
||||
|
||||
if axis is None:
|
||||
a = ravel(a)
|
||||
@ -8960,8 +8941,7 @@ def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike,
|
||||
>>> jnp.prod(x, [0, 1], keepdims=True)
|
||||
Array([[720]], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("apply_over_axes", a)
|
||||
a_arr = asarray(a)
|
||||
a_arr = util.ensure_arraylike("apply_over_axes", a)
|
||||
for axis in axes:
|
||||
b = func(a_arr, axis)
|
||||
if b.ndim == a_arr.ndim:
|
||||
@ -9041,9 +9021,8 @@ def dot(a: ArrayLike, b: ArrayLike, *,
|
||||
>>> jnp.matmul(a, b).shape
|
||||
(3, 2, 1)
|
||||
"""
|
||||
util.check_arraylike("dot", a, b)
|
||||
a, b = util.ensure_arraylike("dot", a, b)
|
||||
dtypes.check_user_dtype_supported(preferred_element_type, "dot")
|
||||
a, b = asarray(a), asarray(b)
|
||||
if preferred_element_type is None:
|
||||
preferred_element_type, output_weak_type = dtypes.result_type(a, b, return_weak_type_flag=True)
|
||||
else:
|
||||
@ -9124,9 +9103,8 @@ def matmul(a: ArrayLike, b: ArrayLike, *,
|
||||
Array([[22, 28],
|
||||
[49, 64]], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("matmul", a, b)
|
||||
a, b = util.ensure_arraylike("matmul", a, b)
|
||||
dtypes.check_user_dtype_supported(preferred_element_type, "matmul")
|
||||
a, b = asarray(a), asarray(b)
|
||||
for i, x in enumerate((a, b)):
|
||||
if ndim(x) < 1:
|
||||
msg = (f"matmul input operand {i} must have ndim at least 1, "
|
||||
@ -9368,8 +9346,7 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
|
||||
>>> jnp.linalg.vecdot(a, b, axis=-1)
|
||||
Array([20, 47], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("jnp.vecdot", x1, x2)
|
||||
x1_arr, x2_arr = asarray(x1), asarray(x2)
|
||||
x1_arr, x2_arr = util.ensure_arraylike("jnp.vecdot", x1, x2)
|
||||
if x1_arr.shape[axis] != x2_arr.shape[axis]:
|
||||
raise ValueError(f"axes must match; got shapes {x1_arr.shape} and {x2_arr.shape} with {axis=}")
|
||||
x1_arr = jax.numpy.moveaxis(x1_arr, axis, -1)
|
||||
@ -9454,9 +9431,8 @@ def tensordot(a: ArrayLike, b: ArrayLike,
|
||||
Array([[1, 2, 3],
|
||||
[2, 4, 6]], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("tensordot", a, b)
|
||||
a, b = util.ensure_arraylike("tensordot", a, b)
|
||||
dtypes.check_user_dtype_supported(preferred_element_type, "tensordot")
|
||||
a, b = asarray(a), asarray(b)
|
||||
a_ndim = ndim(a)
|
||||
b_ndim = ndim(b)
|
||||
|
||||
@ -10083,7 +10059,7 @@ def inner(
|
||||
>>> jnp.inner(a, b).shape
|
||||
(2, 5)
|
||||
"""
|
||||
util.check_arraylike("inner", a, b)
|
||||
a, b = util.ensure_arraylike("inner", a, b)
|
||||
if ndim(a) == 0 or ndim(b) == 0:
|
||||
a = asarray(a, dtype=preferred_element_type)
|
||||
b = asarray(b, dtype=preferred_element_type)
|
||||
@ -10320,8 +10296,7 @@ def vander(
|
||||
[ 1, 3, 9, 27],
|
||||
[ 1, 4, 16, 64]], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("vander", x)
|
||||
x = asarray(x)
|
||||
x = util.ensure_arraylike("vander", x)
|
||||
if x.ndim != 1:
|
||||
raise ValueError("x must be a one-dimensional array")
|
||||
N = x.shape[0] if N is None else core.concrete_or_error(
|
||||
@ -10440,10 +10415,10 @@ def argmax(a: ArrayLike, axis: int | None = None, out: None = None,
|
||||
Array([[1],
|
||||
[0]], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("argmax", a)
|
||||
arr = util.ensure_arraylike("argmax", a)
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.argmax is not supported.")
|
||||
return _argmax(asarray(a), None if axis is None else operator.index(axis),
|
||||
return _argmax(arr, None if axis is None else operator.index(axis),
|
||||
keepdims=bool(keepdims))
|
||||
|
||||
@partial(jit, static_argnames=('axis', 'keepdims'), inline=True)
|
||||
@ -10496,10 +10471,10 @@ def argmin(a: ArrayLike, axis: int | None = None, out: None = None,
|
||||
Array([[0],
|
||||
[2]], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("argmin", a)
|
||||
arr = util.ensure_arraylike("argmin", a)
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.argmin is not supported.")
|
||||
return _argmin(asarray(a), None if axis is None else operator.index(axis),
|
||||
return _argmin(arr, None if axis is None else operator.index(axis),
|
||||
keepdims=bool(keepdims))
|
||||
|
||||
@partial(jit, static_argnames=('axis', 'keepdims'), inline=True)
|
||||
@ -10693,17 +10668,15 @@ def sort(
|
||||
- :func:`jax.numpy.lexsort`: lexicographical sort of multiple arrays.
|
||||
- :func:`jax.lax.sort`: lower-level function wrapping XLA's Sort operator.
|
||||
"""
|
||||
util.check_arraylike("sort", a)
|
||||
arr = util.ensure_arraylike("sort", a)
|
||||
if kind is not None:
|
||||
raise TypeError("'kind' argument to sort is not supported. Use"
|
||||
" stable=True or stable=False to specify sort stability.")
|
||||
if order is not None:
|
||||
raise TypeError("'order' argument to sort is not supported.")
|
||||
if axis is None:
|
||||
arr = ravel(a)
|
||||
arr = arr.ravel()
|
||||
axis = 0
|
||||
else:
|
||||
arr = asarray(a)
|
||||
dimension = _canonicalize_axis(axis, arr.ndim)
|
||||
result = lax.sort(arr, dimension=dimension, is_stable=stable)
|
||||
return lax.rev(result, dimensions=[dimension]) if descending else result
|
||||
@ -10742,8 +10715,8 @@ def sort_complex(a: ArrayLike) -> Array:
|
||||
Array([[3.+0.j, 4.+0.j, 5.+0.j],
|
||||
[2.+0.j, 6.+0.j, 9.+0.j]], dtype=complex64)
|
||||
"""
|
||||
util.check_arraylike("sort_complex", a)
|
||||
a = lax.sort(asarray(a))
|
||||
a = util.ensure_arraylike("sort_complex", a)
|
||||
a = lax.sort(a)
|
||||
return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype))
|
||||
|
||||
|
||||
@ -10810,9 +10783,7 @@ def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> A
|
||||
Array([[0, 1, 0, 1],
|
||||
[1, 0, 1, 0]], dtype=int32)
|
||||
"""
|
||||
key_tuple = tuple(keys)
|
||||
util.check_arraylike("lexsort", *key_tuple)
|
||||
key_arrays = tuple(asarray(k) for k in key_tuple)
|
||||
key_arrays = util.ensure_arraylike_tuple("lexsort", tuple(keys))
|
||||
if len(key_arrays) == 0:
|
||||
raise TypeError("need sequence of keys with len > 0 in lexsort")
|
||||
if len({shape(key) for key in key_arrays}) > 1:
|
||||
@ -10881,18 +10852,15 @@ def argsort(
|
||||
- :func:`jax.numpy.lexsort`: lexicographical sort of multiple arrays.
|
||||
- :func:`jax.lax.sort`: lower-level function wrapping XLA's Sort operator.
|
||||
"""
|
||||
util.check_arraylike("argsort", a)
|
||||
arr = asarray(a)
|
||||
arr = util.ensure_arraylike("argsort", a)
|
||||
if kind is not None:
|
||||
raise TypeError("'kind' argument to argsort is not supported. Use"
|
||||
" stable=True or stable=False to specify sort stability.")
|
||||
if order is not None:
|
||||
raise TypeError("'order' argument to argsort is not supported.")
|
||||
if axis is None:
|
||||
arr = ravel(arr)
|
||||
arr = arr.ravel()
|
||||
axis = 0
|
||||
else:
|
||||
arr = asarray(a)
|
||||
dimension = _canonicalize_axis(axis, arr.ndim)
|
||||
use_64bit_index = not core.is_constant_dim(arr.shape[dimension]) or arr.shape[dimension] >= (1 << 31)
|
||||
iota = lax.broadcasted_iota(int64 if use_64bit_index else int_, arr.shape, dimension)
|
||||
@ -10959,8 +10927,7 @@ def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
|
||||
order is arbitrary and implementation-dependent.
|
||||
"""
|
||||
# TODO(jakevdp): handle NaN values like numpy.
|
||||
util.check_arraylike("partition", a)
|
||||
arr = asarray(a)
|
||||
arr = util.ensure_arraylike("partition", a)
|
||||
if issubdtype(arr.dtype, np.complexfloating):
|
||||
raise NotImplementedError("jnp.partition for complex dtype is not implemented.")
|
||||
axis = _canonicalize_axis(axis, arr.ndim)
|
||||
@ -11031,8 +10998,7 @@ def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
|
||||
order is arbitrary and implementation-dependent.
|
||||
"""
|
||||
# TODO(jakevdp): handle NaN values like numpy.
|
||||
util.check_arraylike("partition", a)
|
||||
arr = asarray(a)
|
||||
arr = util.ensure_arraylike("partition", a)
|
||||
if issubdtype(arr.dtype, np.complexfloating):
|
||||
raise NotImplementedError("jnp.argpartition for complex dtype is not implemented.")
|
||||
axis = _canonicalize_axis(axis, arr.ndim)
|
||||
@ -11123,8 +11089,7 @@ def roll(a: ArrayLike, shift: ArrayLike | Sequence[int],
|
||||
[ 9, 10, 11, 8],
|
||||
[ 1, 2, 3, 0]], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("roll", a)
|
||||
arr = asarray(a)
|
||||
arr = util.ensure_arraylike("roll", a)
|
||||
if axis is None:
|
||||
return roll(arr.ravel(), shift, 0).reshape(arr.shape)
|
||||
axis = _ensure_index_tuple(axis)
|
||||
@ -11262,8 +11227,7 @@ def packbits(a: ArrayLike, axis: int | None = None, bitorder: str = "big") -> Ar
|
||||
Array([[1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0],
|
||||
[0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1]], dtype=uint8)
|
||||
"""
|
||||
util.check_arraylike("packbits", a)
|
||||
arr = asarray(a)
|
||||
arr = util.ensure_arraylike("packbits", a)
|
||||
if not (issubdtype(arr.dtype, integer) or issubdtype(arr.dtype, bool_)):
|
||||
raise TypeError('Expected an input array of integer or boolean data type')
|
||||
if bitorder not in ['little', 'big']:
|
||||
@ -11357,9 +11321,8 @@ def unpackbits(
|
||||
>>> jnp.unpackbits(vals, count=-5) # specify 5 bits to be trimmed
|
||||
Array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1], dtype=uint8)
|
||||
"""
|
||||
util.check_arraylike("unpackbits", a)
|
||||
arr = asarray(a)
|
||||
if _dtype(a) != uint8:
|
||||
arr = util.ensure_arraylike("unpackbits", a)
|
||||
if arr.dtype != uint8:
|
||||
raise TypeError("Expected an input array of unsigned byte data type")
|
||||
if bitorder not in ['little', 'big']:
|
||||
raise ValueError("'order' must be either 'little' or 'big'")
|
||||
@ -11473,9 +11436,7 @@ def _take(a, indices, axis: int | None = None, out=None, mode=None,
|
||||
unique_indices=False, indices_are_sorted=False, fill_value=None):
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.take is not supported.")
|
||||
util.check_arraylike("take", a, indices)
|
||||
a = asarray(a)
|
||||
indices = asarray(indices)
|
||||
a, indices = util.ensure_arraylike("take", a, indices)
|
||||
|
||||
if axis is None:
|
||||
a = ravel(a)
|
||||
@ -11618,8 +11579,7 @@ def take_along_axis(
|
||||
Array([[3],
|
||||
[2]], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("take_along_axis", arr, indices)
|
||||
a = asarray(arr)
|
||||
a, indices = util.ensure_arraylike("take_along_axis", arr, indices)
|
||||
index_dtype = dtypes.dtype(indices)
|
||||
idx_shape = shape(indices)
|
||||
if not dtypes.issubdtype(index_dtype, integer):
|
||||
@ -11791,10 +11751,7 @@ def put_along_axis(
|
||||
"jax.numpy.put_along_axis cannot modify arrays in-place, because JAX arrays"
|
||||
"are immutable. Pass inplace=False to instead return an updated array.")
|
||||
|
||||
util.check_arraylike("put_along_axis", arr, indices, values)
|
||||
arr = asarray(arr)
|
||||
indices = asarray(indices)
|
||||
values = asarray(values)
|
||||
arr, indices, values = util.ensure_arraylike("put_along_axis", arr, indices, values)
|
||||
|
||||
original_axis = axis
|
||||
original_arr_shape = arr.shape
|
||||
@ -12814,17 +12771,17 @@ def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = None,
|
||||
[ 6, 0, 0, 0],
|
||||
[ 9, 12, 0, 0]], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("compress", condition, a, fill_value)
|
||||
condition_arr = asarray(condition).astype(bool)
|
||||
condition_arr, arr, fill_value = util.ensure_arraylike("compress", condition, a, fill_value)
|
||||
condition_arr = condition_arr.astype(bool)
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.compress is not supported.")
|
||||
if condition_arr.ndim != 1:
|
||||
raise ValueError("condition must be a 1D array")
|
||||
if axis is None:
|
||||
axis = 0
|
||||
arr = ravel(a)
|
||||
arr = ravel(arr)
|
||||
else:
|
||||
arr = moveaxis(a, axis, 0)
|
||||
arr = moveaxis(arr, axis, 0)
|
||||
condition_arr, extra = condition_arr[:arr.shape[0]], condition_arr[arr.shape[0]:]
|
||||
arr = arr[:condition_arr.shape[0]]
|
||||
|
||||
@ -12965,7 +12922,7 @@ def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True,
|
||||
|
||||
w: Array | None = None
|
||||
if fweights is not None:
|
||||
util.check_arraylike("cov", fweights)
|
||||
fweights = util.ensure_arraylike("cov", fweights)
|
||||
if ndim(fweights) > 1:
|
||||
raise RuntimeError("cannot handle multidimensional fweights")
|
||||
if shape(fweights)[0] != X.shape[1]:
|
||||
@ -12973,16 +12930,16 @@ def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True,
|
||||
if not issubdtype(_dtype(fweights), integer):
|
||||
raise TypeError("fweights must be integer.")
|
||||
# Ensure positive fweights; note that numpy raises an error on negative fweights.
|
||||
w = asarray(ufuncs.abs(fweights))
|
||||
w = abs(fweights)
|
||||
if aweights is not None:
|
||||
util.check_arraylike("cov", aweights)
|
||||
aweights = util.ensure_arraylike("cov", aweights)
|
||||
if ndim(aweights) > 1:
|
||||
raise RuntimeError("cannot handle multidimensional aweights")
|
||||
if shape(aweights)[0] != X.shape[1]:
|
||||
raise RuntimeError("incompatible numbers of samples and aweights")
|
||||
# Ensure positive aweights: note that numpy raises an error for negative aweights.
|
||||
aweights = ufuncs.abs(aweights)
|
||||
w = asarray(aweights) if w is None else w * asarray(aweights)
|
||||
aweights = abs(aweights)
|
||||
w = aweights if w is None else w * aweights
|
||||
|
||||
avg, w_sum = reductions.average(X, axis=1, weights=w, returned=True)
|
||||
w_sum = w_sum[0]
|
||||
@ -13218,7 +13175,7 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left',
|
||||
'sort': _searchsorted_via_sort,
|
||||
'compare_all': _searchsorted_via_compare_all,
|
||||
}[method]
|
||||
return impl(asarray(a), asarray(v), side, dtype) # type: ignore
|
||||
return impl(a, v, side, dtype) # type: ignore
|
||||
|
||||
|
||||
@export
|
||||
@ -13261,9 +13218,8 @@ def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False,
|
||||
>>> jnp.digitize(x, bins)
|
||||
Array([2, 1, 1, 2, 0, 0], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("digitize", x, bins)
|
||||
x, bins_arr = util.ensure_arraylike("digitize", x, bins)
|
||||
right = core.concrete_or_error(bool, right, "right argument of jnp.digitize()")
|
||||
bins_arr = asarray(bins)
|
||||
if bins_arr.ndim != 1:
|
||||
raise ValueError(f"digitize: bins must be a 1-dimensional array; got {bins=}")
|
||||
if bins_arr.shape[0] == 0:
|
||||
@ -13347,7 +13303,7 @@ def piecewise(x: ArrayLike, condlist: Array | Sequence[ArrayLike],
|
||||
>>> jnp.piecewise(x, condlist, funclist)
|
||||
Array([-40, -30, -20, -10, 0, 10, 20, 30, 40], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("piecewise", x)
|
||||
x_arr = util.ensure_arraylike("piecewise", x)
|
||||
nc, nf = len(condlist), len(funclist)
|
||||
if nf == nc + 1:
|
||||
funclist = funclist[-1:] + funclist[:-1]
|
||||
@ -13357,7 +13313,7 @@ def piecewise(x: ArrayLike, condlist: Array | Sequence[ArrayLike],
|
||||
raise ValueError(f"with {nc} condition(s), either {nc} or {nc+1} functions are expected; got {nf}")
|
||||
consts = {i: c for i, c in enumerate(funclist) if not callable(c)}
|
||||
funcs = {i: f for i, f in enumerate(funclist) if callable(f)}
|
||||
return _piecewise(asarray(x), asarray(condlist, dtype=bool_), consts,
|
||||
return _piecewise(x_arr, asarray(condlist, dtype=bool_), consts,
|
||||
frozenset(funcs.items()), # dict is not hashable.
|
||||
*args, **kw)
|
||||
|
||||
@ -13444,7 +13400,8 @@ def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *,
|
||||
[0, 5, 0, 0, 1],
|
||||
[0, 0, 3, 0, 0]], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("place", arr, mask, vals)
|
||||
data, mask_arr, vals_arr = util.ensure_arraylike("place", arr, mask, vals)
|
||||
vals_arr = vals_arr.ravel()
|
||||
data, mask_arr, vals_arr = asarray(arr), asarray(mask), ravel(vals)
|
||||
if inplace:
|
||||
raise ValueError(
|
||||
@ -13526,8 +13483,9 @@ def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike,
|
||||
[ 0, 0, 20, 0, 0],
|
||||
[ 0, 0, 0, 0, 30]], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("put", a, ind, v)
|
||||
arr, ind_arr, v_arr = asarray(a), ravel(ind), ravel(v)
|
||||
arr, ind_arr, _ = util.ensure_arraylike("put", a, ind, v)
|
||||
ind_arr = ind_arr.ravel()
|
||||
v_arr = ravel(v)
|
||||
if not arr.size or not ind_arr.size or not v_arr.size:
|
||||
return arr
|
||||
v_arr = _tile_to_size(v_arr, len(ind_arr))
|
||||
|
@ -15,7 +15,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from typing import Any, overload
|
||||
|
||||
import warnings
|
||||
|
||||
@ -133,6 +133,42 @@ def _arraylike(x: ArrayLike) -> bool:
|
||||
hasattr(x, '__jax_array__') or np.isscalar(x))
|
||||
|
||||
|
||||
def _arraylike_asarray(x: Any) -> Array:
|
||||
"""Convert an array-like object to an array."""
|
||||
if hasattr(x, '__jax_array__'):
|
||||
x = x.__jax_array__()
|
||||
elif isinstance(x, (bool, int, float, complex)):
|
||||
x = dtypes.coerce_to_array(x)
|
||||
return lax.asarray(x)
|
||||
|
||||
|
||||
@overload
|
||||
def ensure_arraylike(fun_name: str, /) -> tuple[()]: ...
|
||||
@overload
|
||||
def ensure_arraylike(fun_name: str, a1: Any, /) -> Array: ...
|
||||
@overload
|
||||
def ensure_arraylike(fun_name: str, a1: Any, a2: Any, /) -> tuple[Array, Array]: ...
|
||||
@overload
|
||||
def ensure_arraylike(fun_name: str, a1: Any, a2: Any, a3: Any, /) -> tuple[Array, Array, Array]: ...
|
||||
@overload
|
||||
def ensure_arraylike(fun_name: str, a1: Any, a2: Any, a3: Any, a4: Any, /, *args: Any) -> tuple[Array, ...]: ...
|
||||
def ensure_arraylike(fun_name: str, /, *args: Any) -> Array | tuple[Array, ...]:
|
||||
"""Check that arguments are arraylike and convert them to arrays."""
|
||||
check_arraylike(fun_name, *args)
|
||||
if len(args) == 1:
|
||||
return _arraylike_asarray(args[0]) # pytype: disable=bad-return-type
|
||||
return tuple(_arraylike_asarray(arg) for arg in args) # pytype: disable=bad-return-type
|
||||
|
||||
|
||||
def ensure_arraylike_tuple(fun_name: str, tup: tuple[Any, ...]) -> tuple[Array, ...]:
|
||||
"""Check that argument elements are arraylike and convert to a tuple of arrays.
|
||||
|
||||
This is useful because ensure_arraylike with a single argument returns a single array.
|
||||
"""
|
||||
check_arraylike(fun_name, *tup)
|
||||
return tuple(_arraylike_asarray(arg) for arg in tup)
|
||||
|
||||
|
||||
def check_arraylike(fun_name: str, *args: Any, emit_warning=False, stacklevel=3):
|
||||
"""Check if all args fit JAX's definition of arraylike."""
|
||||
assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}"
|
||||
|
Loading…
x
Reference in New Issue
Block a user