refactor: move lax_numpy indexing routines to their own submodule

Jake VanderPlas · 2025-02-12 11:52:11 -08:00
commit f750d0b855 (parent 5ebb7eb55d)
10 changed files with 1291 additions and 1250 deletions
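The public jax.numpy indexing API is unchanged by this refactor; only the private home of the implementations moves from jax._src.numpy.lax_numpy to the new jax._src.numpy.indexing, and the helpers it takes along lose their leading underscores (e.g. _rewriting_take becomes rewriting_take). A quick sketch of the unaffected public surface, using arbitrary example values and only stock jax.numpy calls:

    import jax.numpy as jnp

    x = jnp.arange(6.0).reshape(2, 3)
    idx = jnp.array([2, 0])

    # All of these continue to route through the (now relocated) indexing helpers.
    print(x[:, idx])                 # advanced indexing via Array.__getitem__
    print(jnp.take(x, idx, axis=1))  # functional form
    print(x.at[0, 1].set(100.0))     # out-of-place update via the .at[] helper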

View File

@@ -383,7 +383,7 @@ class ArrayImpl(basearray.Array):
   def __getitem__(self, idx):
     from jax._src.lax import lax
-    from jax._src.numpy import lax_numpy
+    from jax._src.numpy import indexing
     self._check_if_deleted()
     if isinstance(self.sharding, PmapSharding):
@@ -418,7 +418,7 @@ class ArrayImpl(basearray.Array):
       return ArrayImpl(
           out.aval, sharding, [out], committed=False, _skip_checks=True)
-    return lax_numpy._rewriting_take(self, idx)
+    return indexing.rewriting_take(self, idx)

   def __iter__(self):
     if self.ndim == 0:
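ArrayImpl.__getitem__ is the entry point for arr[idx]; the only change here is that it now resolves to indexing.rewriting_take rather than the old private lax_numpy._rewriting_take. For illustration, the kinds of expressions this routine rewrites (values are arbitrary):

    import jax.numpy as jnp

    x = jnp.arange(10)
    # Each __getitem__ below is rewritten into a lax gather (or a slice fast path).
    print(x[3])                  # scalar index
    print(x[2:7:2])              # static slice
    print(x[jnp.array([1, 8])])  # advanced (array) indexing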

View File

@@ -40,6 +40,7 @@ from jax._src.array import ArrayImpl
 from jax._src.lax import lax as lax_internal
 from jax._src.lib import xla_client as xc
 from jax._src.numpy import array_api_metadata
+from jax._src.numpy import indexing
 from jax._src.numpy import lax_numpy
 from jax._src.numpy import tensor_contractions
 from jax._src import mesh as mesh_lib
@@ -382,8 +383,8 @@ def _take(self: Array, indices: ArrayLike, axis: int | None = None, out: None = None,
   Refer to :func:`jax.numpy.take` for full documentation.
   """
-  return lax_numpy.take(self, indices, axis=axis, out=out, mode=mode, unique_indices=unique_indices,
-                        indices_are_sorted=indices_are_sorted, fill_value=fill_value)
+  return indexing.take(self, indices, axis=axis, out=out, mode=mode, unique_indices=unique_indices,
+                       indices_are_sorted=indices_are_sorted, fill_value=fill_value)

 def _to_device(self: Array, device: xc.Device | Sharding, *,
                stream: int | Any | None = None):
@@ -649,7 +650,7 @@ def _chunk_iter(x, size):
     yield lax.dynamic_slice_in_dim(x, num_chunks * size, tail)

 def _getitem(self, item):
-  return lax_numpy._rewriting_take(self, item)
+  return indexing.rewriting_take(self, item)

 # Syntactic sugar for scatter operations.
 class _IndexUpdateHelper:
@@ -777,7 +778,7 @@ class _IndexUpdateRef:
     See :mod:`jax.ops` for details.
     """
-    take = partial(lax_numpy._rewriting_take,
+    take = partial(indexing.rewriting_take,
                    indices_are_sorted=indices_are_sorted,
                    unique_indices=unique_indices, mode=mode,
                    fill_value=fill_value)
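The three call sites in this file back, respectively, the Array.take method, method-style __getitem__, and the .at[...] index-update helper. A sketch of the public entry points they serve, on arbitrary values:

    import jax.numpy as jnp

    x = jnp.array([10., 20., 30., 40.])
    idx = jnp.array([3, 0, 3])

    print(x.take(idx))   # Array.take -> indexing.take
    print(x[idx])        # _getitem   -> indexing.rewriting_take
    print(x.at[idx].get(mode="fill", fill_value=-1.0))   # _IndexUpdateRef.get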

jax/_src/numpy/indexing.py (new file, 1248 additions)

File diff suppressed because it is too large
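The new module itself is not shown (its 1248-line diff is suppressed), but the call sites updated elsewhere in this commit imply at least the following names now live there. This is an inference from the rest of the diff, not a listing of the file:

    # Inferred from the updated call sites in this commit; not exhaustive.
    from jax._src.numpy.indexing import (
        take,                              # jnp.take
        take_along_axis,                   # jnp.take_along_axis
        place, put, put_along_axis,        # re-exported by jax.numpy below
        rewriting_take,                    # backs Array.__getitem__ / .at[...].get
        index_to_gather,                   # index expression -> gather spec
        split_index_for_jit,               # static/dynamic index split for jit
        merge_static_and_dynamic_indices,
        eliminate_deprecated_list_indexing,
    )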

File diff suppressed because it is too large

View File

@@ -31,6 +31,8 @@ from jax._src import deprecations
 from jax._src.lax import lax as lax_internal
 from jax._src.lax.lax import PrecisionLike
 from jax._src.lax import linalg as lax_linalg
+from jax._src.numpy import einsum
+from jax._src.numpy import indexing
 from jax._src.numpy import lax_numpy as jnp
 from jax._src.numpy import reductions, tensor_contractions, ufuncs
 from jax._src.numpy.util import promote_dtypes_inexact, ensure_arraylike
@@ -292,7 +294,7 @@ def svd(
     s = lax.rev(s, dimensions=[s.ndim - 1])
     idxs = lax.rev(idxs, dimensions=[s.ndim - 1])
     sign = lax.rev(sign, dimensions=[s.ndim - 1])
-    u = jnp.take_along_axis(w, idxs[..., None, :], axis=-1)
+    u = indexing.take_along_axis(w, idxs[..., None, :], axis=-1)
     vh = _H(u * sign[..., None, :].astype(u.dtype))
     return SVDResult(u, s, vh)
   else:
@@ -2115,8 +2117,8 @@ def multi_dot(arrays: Sequence[ArrayLike], *, precision: PrecisionLike = None) -> Array:
     einsum_axes[0] = einsum_axes[0][1:]
   if arrs[-1].ndim == 1:
     einsum_axes[-1] = einsum_axes[-1][:1]
-  return jnp.einsum(*itertools.chain(*zip(arrs, einsum_axes)),  # type: ignore[call-overload]
-                    optimize='auto', precision=precision)
+  return einsum.einsum(*itertools.chain(*zip(arrs, einsum_axes)),  # type: ignore[call-overload]
+                       optimize='auto', precision=precision)


 @export
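In svd, the sign/ordering fix-up gathers columns of w with take_along_axis, and multi_dot now calls the einsum submodule directly instead of going through lax_numpy. The same gather pattern in public jax.numpy terms, on made-up data:

    import jax.numpy as jnp

    w = jnp.array([[0.3, 0.9, 0.1],
                   [0.8, 0.2, 0.5]])
    idxs = jnp.argsort(w, axis=-1)
    # Reorder the last axis of w according to idxs -- the shape-preserving
    # gather used by the svd fix-up above.
    print(jnp.take_along_axis(w, idxs, axis=-1))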

View File

@@ -28,9 +28,10 @@ from jax import lax
 from jax._src import core
 from jax._src import dtypes
 from jax._src.lax import lax as lax_internal
+from jax._src.numpy.array_creation import empty, full, full_like, ones, zeros
 from jax._src.numpy.lax_numpy import (
-    append, arange, concatenate, diff, empty, full, full_like,
-    moveaxis, nonzero, ones, ravel, sort, where, zeros)
+    append, arange, concatenate, diff,
+    moveaxis, nonzero, ravel, sort, where)
 from jax._src.numpy.reductions import any, cumsum
 from jax._src.numpy.sorting import lexsort
 from jax._src.numpy.ufuncs import isnan

View File

@@ -25,6 +25,7 @@ from typing import Any
 import jax
 from jax._src.typing import Array, ArrayLike, DTypeLike
 from jax._src.lax import lax as lax_internal
+from jax._src.numpy import indexing
 import jax._src.numpy.lax_numpy as jnp
 from jax._src.numpy.reductions import _moveaxis
 from jax._src.numpy.util import check_arraylike, _broadcast_to, _where
@@ -442,7 +443,7 @@ class ufunc:
       dtype = jax.eval_shape(self._func, lax_internal._one(a), *(lax_internal._one(arg) for arg in args)).dtype
     a = lax_internal.asarray(a).astype(dtype)
     args = tuple(lax_internal.asarray(arg).astype(dtype) for arg in args)
-    indices = jnp._eliminate_deprecated_list_indexing(indices)
+    indices = indexing.eliminate_deprecated_list_indexing(indices)
     if not indices:
       return a
@@ -517,7 +518,7 @@ class ufunc:
                dtype: DTypeLike | None = None) -> Array:
     check_arraylike(f"{self.__name__}.reduceat", a, indices)
     a = lax_internal.asarray(a)
-    idx_tuple = jnp._eliminate_deprecated_list_indexing(indices)
+    idx_tuple = indexing.eliminate_deprecated_list_indexing(indices)
     assert len(idx_tuple) == 1
     indices = idx_tuple[0]
     if a.ndim == 0:
@@ -529,14 +530,14 @@
     if axis is None or isinstance(axis, (tuple, list)):
       raise ValueError("reduceat requires a single integer axis.")
     axis = canonicalize_axis(axis, a.ndim)
-    out = jnp.take(a, indices, axis=axis)
+    out = indexing.take(a, indices, axis=axis)
     ind = jax.lax.expand_dims(jnp.append(indices, a.shape[axis]),
                               list(np.delete(np.arange(out.ndim), axis)))
     ind_start = jax.lax.slice_in_dim(ind, 0, ind.shape[axis] - 1, axis=axis)
     ind_end = jax.lax.slice_in_dim(ind, 1, ind.shape[axis], axis=axis)

     def loop_body(i, out):
       return _where((i > ind_start) & (i < ind_end),
-                    self(out, jnp.take(a, jax.lax.expand_dims(i, (0,)), axis=axis)),
+                    self(out, indexing.take(a, jax.lax.expand_dims(i, (0,)), axis=axis)),
                     out)
     return jax.lax.fori_loop(0, a.shape[axis], loop_body, out)
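reduceat seeds the output with indexing.take at the segment starts, then folds the remaining elements in with a fori_loop, masking each element into its segment with _where. The public behavior matches numpy.ufunc.reduceat; a small example, assuming a jax release in which jnp.add is a jnp.ufunc:

    import jax.numpy as jnp

    a = jnp.arange(8)
    indices = jnp.array([0, 4, 1, 5, 2, 6, 3, 7])
    # out[i] reduces a[indices[i]:indices[i+1]]; where that slice is empty or
    # backwards, out[i] is just a[indices[i]], as in NumPy.
    print(jnp.add.reduceat(a, indices))   # [ 6  4 10  5 14  6 18  7]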

View File

@@ -29,6 +29,7 @@ from jax._src import core
 from jax._src import dtypes
 from jax._src import util
 from jax._src.lax import lax as lax_internal
+from jax._src.numpy import indexing
 from jax._src.numpy import lax_numpy as jnp
 from jax._src.numpy import reductions
 from jax._src.numpy.util import check_arraylike, promote_dtypes
@@ -72,7 +73,7 @@ def _scatter_update(x, idx, y, scatter_op, indices_are_sorted,
   # XLA gathers and scatters are very similar in structure; the scatter logic
   # is more or less a transpose of the gather equivalent.
-  treedef, static_idx, dynamic_idx = jnp._split_index_for_jit(idx, x.shape)
+  treedef, static_idx, dynamic_idx = indexing.split_index_for_jit(idx, x.shape)
   return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                        indices_are_sorted, unique_indices, mode,
                        normalize_indices)
@@ -96,9 +97,9 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                   "In future JAX releases this will result in an error.",
                   FutureWarning)

-  idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
-  indexer = jnp._index_to_gather(jnp.shape(x), idx,
-                                 normalize_indices=normalize_indices)
+  idx = indexing.merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
+  indexer = indexing.index_to_gather(jnp.shape(x), idx,
+                                     normalize_indices=normalize_indices)

   # Avoid calling scatter if the slice shape is empty, both as a fast path and
   # to handle cases like zeros(0)[array([], int32)].
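_scatter_update and _scatter_impl implement the .at[...] update path: the index expression is split into static and dynamic parts so it can cross a jit boundary, re-merged inside the trace, and converted to a gather/scatter spec by index_to_gather. The public entry point, with arbitrary values:

    import jax
    import jax.numpy as jnp

    x = jnp.zeros(5)
    idx = jnp.array([1, 3, 1])

    # Out-of-place scatter-add; duplicate indices accumulate.
    print(x.at[idx].add(1.0))                        # [0. 2. 0. 1. 0.]
    # The same path works under jit; the static/dynamic split keeps the
    # non-array parts of the index hashable for tracing.
    print(jax.jit(lambda v: v.at[idx].add(1.0))(x))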

View File

@@ -70,7 +70,7 @@ from jax._src.interpreters import partial_eval as pe
 from jax.tree_util import tree_flatten, tree_map, tree_unflatten
 from jax.util import safe_map, safe_zip, split_list
 from jax._src.lax.control_flow import _check_tree_and_avals
-from jax._src.numpy import lax_numpy
+from jax._src.numpy import indexing as jnp_indexing
 from jax.experimental import sparse
 from jax.experimental.sparse import BCOO, BCSR
@@ -914,7 +914,7 @@ def _bcoo_rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
                          mode=None, fill_value=None):
   # Only sparsify the array argument; sparse indices not yet supported
   result = sparsify(functools.partial(
-    lax_numpy._rewriting_take, idx=idx, indices_are_sorted=indices_are_sorted,
+    jnp_indexing.rewriting_take, idx=idx, indices_are_sorted=indices_are_sorted,
     mode=mode, unique_indices=unique_indices, fill_value=fill_value))(arr)
   # Account for a corner case in the rewriting_take implementation.
   if not isinstance(result, BCOO) and np.size(result) == 0:
@@ -966,7 +966,7 @@ def _bcsr_rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
                          mode=None, fill_value=None):
   # Only sparsify the array argument; sparse indices not yet supported
   result = sparsify(functools.partial(
-    lax_numpy._rewriting_take, idx=idx, indices_are_sorted=indices_are_sorted,
+    jnp_indexing.rewriting_take, idx=idx, indices_are_sorted=indices_are_sorted,
     mode=mode, unique_indices=unique_indices, fill_value=fill_value))(arr)
   return result
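On the sparse side, indexing of BCOO/BCSR operands reuses the dense rule: rewriting_take is wrapped in sparsify, so the same indexing semantics apply to sparse arrays. A small BCOO sketch, assuming BCOO.__getitem__ is wired to this rule as the function names suggest:

    import jax.numpy as jnp
    from jax.experimental import sparse

    M = sparse.BCOO.fromdense(jnp.array([[1., 0., 2.],
                                         [0., 3., 0.]]))
    row = M[1]            # goes through the sparsified rewriting_take
    print(row.todense())  # [0. 3. 0.]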

View File

@@ -144,11 +144,8 @@ from jax._src.numpy.lax_numpy import (
     permute_dims as permute_dims,
     pi as pi,
     piecewise as piecewise,
-    place as place,
     printoptions as printoptions,
     promote_types as promote_types,
-    put as put,
-    put_along_axis as put_along_axis,
     ravel as ravel,
     ravel_multi_index as ravel_multi_index,
     repeat as repeat,
@@ -170,8 +167,6 @@ from jax._src.numpy.lax_numpy import (
     squeeze as squeeze,
     stack as stack,
     swapaxes as swapaxes,
-    take as take,
-    take_along_axis as take_along_axis,
     tile as tile,
     trace as trace,
     trapezoid as trapezoid,
@@ -211,6 +206,14 @@ from jax._src.numpy.einsum import (
     einsum_path as einsum_path,
 )
+from jax._src.numpy.indexing import (
+    place as place,
+    put as put,
+    put_along_axis as put_along_axis,
+    take as take,
+    take_along_axis as take_along_axis,
+)
 from jax._src.numpy.scalar_types import (
     bfloat16 as bfloat16,
     bool_ as bool,  # Array API alias for bool_  # noqa: F401
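The five names dropped from the lax_numpy import block above are re-added here from the new indexing module, so nothing disappears from the public jax.numpy namespace. A quick check, using only the public package:

    import jax.numpy as jnp

    # The relocated functions are still exported from jax.numpy.
    for name in ("place", "put", "put_along_axis", "take", "take_along_axis"):
        assert hasattr(jnp, name), name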