mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
refactor: move lax_numpy indexing routines to their own submodule
This commit is contained in:
parent
5ebb7eb55d
commit
f750d0b855
@ -383,7 +383,7 @@ class ArrayImpl(basearray.Array):
|
||||
|
||||
def __getitem__(self, idx):
|
||||
from jax._src.lax import lax
|
||||
from jax._src.numpy import lax_numpy
|
||||
from jax._src.numpy import indexing
|
||||
self._check_if_deleted()
|
||||
|
||||
if isinstance(self.sharding, PmapSharding):
|
||||
@ -418,7 +418,7 @@ class ArrayImpl(basearray.Array):
|
||||
return ArrayImpl(
|
||||
out.aval, sharding, [out], committed=False, _skip_checks=True)
|
||||
|
||||
return lax_numpy._rewriting_take(self, idx)
|
||||
return indexing.rewriting_take(self, idx)
|
||||
|
||||
def __iter__(self):
|
||||
if self.ndim == 0:
|
||||
|
@ -40,6 +40,7 @@ from jax._src.array import ArrayImpl
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.numpy import array_api_metadata
|
||||
from jax._src.numpy import indexing
|
||||
from jax._src.numpy import lax_numpy
|
||||
from jax._src.numpy import tensor_contractions
|
||||
from jax._src import mesh as mesh_lib
|
||||
@ -382,8 +383,8 @@ def _take(self: Array, indices: ArrayLike, axis: int | None = None, out: None =
|
||||
|
||||
Refer to :func:`jax.numpy.take` for full documentation.
|
||||
"""
|
||||
return lax_numpy.take(self, indices, axis=axis, out=out, mode=mode, unique_indices=unique_indices,
|
||||
indices_are_sorted=indices_are_sorted, fill_value=fill_value)
|
||||
return indexing.take(self, indices, axis=axis, out=out, mode=mode, unique_indices=unique_indices,
|
||||
indices_are_sorted=indices_are_sorted, fill_value=fill_value)
|
||||
|
||||
def _to_device(self: Array, device: xc.Device | Sharding, *,
|
||||
stream: int | Any | None = None):
|
||||
@ -649,7 +650,7 @@ def _chunk_iter(x, size):
|
||||
yield lax.dynamic_slice_in_dim(x, num_chunks * size, tail)
|
||||
|
||||
def _getitem(self, item):
|
||||
return lax_numpy._rewriting_take(self, item)
|
||||
return indexing.rewriting_take(self, item)
|
||||
|
||||
# Syntactic sugar for scatter operations.
|
||||
class _IndexUpdateHelper:
|
||||
@ -777,7 +778,7 @@ class _IndexUpdateRef:
|
||||
|
||||
See :mod:`jax.ops` for details.
|
||||
"""
|
||||
take = partial(lax_numpy._rewriting_take,
|
||||
take = partial(indexing.rewriting_take,
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices, mode=mode,
|
||||
fill_value=fill_value)
|
||||
|
1248
jax/_src/numpy/indexing.py
Normal file
1248
jax/_src/numpy/indexing.py
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -31,6 +31,8 @@ from jax._src import deprecations
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax.lax import PrecisionLike
|
||||
from jax._src.lax import linalg as lax_linalg
|
||||
from jax._src.numpy import einsum
|
||||
from jax._src.numpy import indexing
|
||||
from jax._src.numpy import lax_numpy as jnp
|
||||
from jax._src.numpy import reductions, tensor_contractions, ufuncs
|
||||
from jax._src.numpy.util import promote_dtypes_inexact, ensure_arraylike
|
||||
@ -292,7 +294,7 @@ def svd(
|
||||
s = lax.rev(s, dimensions=[s.ndim - 1])
|
||||
idxs = lax.rev(idxs, dimensions=[s.ndim - 1])
|
||||
sign = lax.rev(sign, dimensions=[s.ndim - 1])
|
||||
u = jnp.take_along_axis(w, idxs[..., None, :], axis=-1)
|
||||
u = indexing.take_along_axis(w, idxs[..., None, :], axis=-1)
|
||||
vh = _H(u * sign[..., None, :].astype(u.dtype))
|
||||
return SVDResult(u, s, vh)
|
||||
else:
|
||||
@ -2115,8 +2117,8 @@ def multi_dot(arrays: Sequence[ArrayLike], *, precision: PrecisionLike = None) -
|
||||
einsum_axes[0] = einsum_axes[0][1:]
|
||||
if arrs[-1].ndim == 1:
|
||||
einsum_axes[-1] = einsum_axes[-1][:1]
|
||||
return jnp.einsum(*itertools.chain(*zip(arrs, einsum_axes)), # type: ignore[call-overload]
|
||||
optimize='auto', precision=precision)
|
||||
return einsum.einsum(*itertools.chain(*zip(arrs, einsum_axes)), # type: ignore[call-overload]
|
||||
optimize='auto', precision=precision)
|
||||
|
||||
|
||||
@export
|
||||
|
@ -28,9 +28,10 @@ from jax import lax
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.numpy.array_creation import empty, full, full_like, ones, zeros
|
||||
from jax._src.numpy.lax_numpy import (
|
||||
append, arange, concatenate, diff, empty, full, full_like,
|
||||
moveaxis, nonzero, ones, ravel, sort, where, zeros)
|
||||
append, arange, concatenate, diff,
|
||||
moveaxis, nonzero, ravel, sort, where)
|
||||
from jax._src.numpy.reductions import any, cumsum
|
||||
from jax._src.numpy.sorting import lexsort
|
||||
from jax._src.numpy.ufuncs import isnan
|
||||
|
@ -25,6 +25,7 @@ from typing import Any
|
||||
import jax
|
||||
from jax._src.typing import Array, ArrayLike, DTypeLike
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.numpy import indexing
|
||||
import jax._src.numpy.lax_numpy as jnp
|
||||
from jax._src.numpy.reductions import _moveaxis
|
||||
from jax._src.numpy.util import check_arraylike, _broadcast_to, _where
|
||||
@ -442,7 +443,7 @@ class ufunc:
|
||||
dtype = jax.eval_shape(self._func, lax_internal._one(a), *(lax_internal._one(arg) for arg in args)).dtype
|
||||
a = lax_internal.asarray(a).astype(dtype)
|
||||
args = tuple(lax_internal.asarray(arg).astype(dtype) for arg in args)
|
||||
indices = jnp._eliminate_deprecated_list_indexing(indices)
|
||||
indices = indexing.eliminate_deprecated_list_indexing(indices)
|
||||
if not indices:
|
||||
return a
|
||||
|
||||
@ -517,7 +518,7 @@ class ufunc:
|
||||
dtype: DTypeLike | None = None) -> Array:
|
||||
check_arraylike(f"{self.__name__}.reduceat", a, indices)
|
||||
a = lax_internal.asarray(a)
|
||||
idx_tuple = jnp._eliminate_deprecated_list_indexing(indices)
|
||||
idx_tuple = indexing.eliminate_deprecated_list_indexing(indices)
|
||||
assert len(idx_tuple) == 1
|
||||
indices = idx_tuple[0]
|
||||
if a.ndim == 0:
|
||||
@ -529,14 +530,14 @@ class ufunc:
|
||||
if axis is None or isinstance(axis, (tuple, list)):
|
||||
raise ValueError("reduceat requires a single integer axis.")
|
||||
axis = canonicalize_axis(axis, a.ndim)
|
||||
out = jnp.take(a, indices, axis=axis)
|
||||
out = indexing.take(a, indices, axis=axis)
|
||||
ind = jax.lax.expand_dims(jnp.append(indices, a.shape[axis]),
|
||||
list(np.delete(np.arange(out.ndim), axis)))
|
||||
ind_start = jax.lax.slice_in_dim(ind, 0, ind.shape[axis] - 1, axis=axis)
|
||||
ind_end = jax.lax.slice_in_dim(ind, 1, ind.shape[axis], axis=axis)
|
||||
def loop_body(i, out):
|
||||
return _where((i > ind_start) & (i < ind_end),
|
||||
self(out, jnp.take(a, jax.lax.expand_dims(i, (0,)), axis=axis)),
|
||||
self(out, indexing.take(a, jax.lax.expand_dims(i, (0,)), axis=axis)),
|
||||
out)
|
||||
return jax.lax.fori_loop(0, a.shape[axis], loop_body, out)
|
||||
|
||||
|
@ -29,6 +29,7 @@ from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import util
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.numpy import indexing
|
||||
from jax._src.numpy import lax_numpy as jnp
|
||||
from jax._src.numpy import reductions
|
||||
from jax._src.numpy.util import check_arraylike, promote_dtypes
|
||||
@ -72,7 +73,7 @@ def _scatter_update(x, idx, y, scatter_op, indices_are_sorted,
|
||||
|
||||
# XLA gathers and scatters are very similar in structure; the scatter logic
|
||||
# is more or less a transpose of the gather equivalent.
|
||||
treedef, static_idx, dynamic_idx = jnp._split_index_for_jit(idx, x.shape)
|
||||
treedef, static_idx, dynamic_idx = indexing.split_index_for_jit(idx, x.shape)
|
||||
return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
|
||||
indices_are_sorted, unique_indices, mode,
|
||||
normalize_indices)
|
||||
@ -96,9 +97,9 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
|
||||
"In future JAX releases this will result in an error.",
|
||||
FutureWarning)
|
||||
|
||||
idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
|
||||
indexer = jnp._index_to_gather(jnp.shape(x), idx,
|
||||
normalize_indices=normalize_indices)
|
||||
idx = indexing.merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
|
||||
indexer = indexing.index_to_gather(jnp.shape(x), idx,
|
||||
normalize_indices=normalize_indices)
|
||||
|
||||
# Avoid calling scatter if the slice shape is empty, both as a fast path and
|
||||
# to handle cases like zeros(0)[array([], int32)].
|
||||
|
@ -70,7 +70,7 @@ from jax._src.interpreters import partial_eval as pe
|
||||
from jax.tree_util import tree_flatten, tree_map, tree_unflatten
|
||||
from jax.util import safe_map, safe_zip, split_list
|
||||
from jax._src.lax.control_flow import _check_tree_and_avals
|
||||
from jax._src.numpy import lax_numpy
|
||||
from jax._src.numpy import indexing as jnp_indexing
|
||||
from jax.experimental import sparse
|
||||
from jax.experimental.sparse import BCOO, BCSR
|
||||
|
||||
@ -914,7 +914,7 @@ def _bcoo_rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=Fals
|
||||
mode=None, fill_value=None):
|
||||
# Only sparsify the array argument; sparse indices not yet supported
|
||||
result = sparsify(functools.partial(
|
||||
lax_numpy._rewriting_take, idx=idx, indices_are_sorted=indices_are_sorted,
|
||||
jnp_indexing.rewriting_take, idx=idx, indices_are_sorted=indices_are_sorted,
|
||||
mode=mode, unique_indices=unique_indices, fill_value=fill_value))(arr)
|
||||
# Account for a corner case in the rewriting_take implementation.
|
||||
if not isinstance(result, BCOO) and np.size(result) == 0:
|
||||
@ -966,7 +966,7 @@ def _bcsr_rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=Fals
|
||||
mode=None, fill_value=None):
|
||||
# Only sparsify the array argument; sparse indices not yet supported
|
||||
result = sparsify(functools.partial(
|
||||
lax_numpy._rewriting_take, idx=idx, indices_are_sorted=indices_are_sorted,
|
||||
jnp_indexing.rewriting_take, idx=idx, indices_are_sorted=indices_are_sorted,
|
||||
mode=mode, unique_indices=unique_indices, fill_value=fill_value))(arr)
|
||||
return result
|
||||
|
||||
|
@ -144,11 +144,8 @@ from jax._src.numpy.lax_numpy import (
|
||||
permute_dims as permute_dims,
|
||||
pi as pi,
|
||||
piecewise as piecewise,
|
||||
place as place,
|
||||
printoptions as printoptions,
|
||||
promote_types as promote_types,
|
||||
put as put,
|
||||
put_along_axis as put_along_axis,
|
||||
ravel as ravel,
|
||||
ravel_multi_index as ravel_multi_index,
|
||||
repeat as repeat,
|
||||
@ -170,8 +167,6 @@ from jax._src.numpy.lax_numpy import (
|
||||
squeeze as squeeze,
|
||||
stack as stack,
|
||||
swapaxes as swapaxes,
|
||||
take as take,
|
||||
take_along_axis as take_along_axis,
|
||||
tile as tile,
|
||||
trace as trace,
|
||||
trapezoid as trapezoid,
|
||||
@ -211,6 +206,14 @@ from jax._src.numpy.einsum import (
|
||||
einsum_path as einsum_path,
|
||||
)
|
||||
|
||||
from jax._src.numpy.indexing import (
|
||||
place as place,
|
||||
put as put,
|
||||
put_along_axis as put_along_axis,
|
||||
take as take,
|
||||
take_along_axis as take_along_axis,
|
||||
)
|
||||
|
||||
from jax._src.numpy.scalar_types import (
|
||||
bfloat16 as bfloat16,
|
||||
bool_ as bool, # Array API alias for bool_ # noqa: F401
|
||||
|
Loading…
x
Reference in New Issue
Block a user