# Copyright 2018 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import enum from functools import partial import itertools import typing from typing import Any from absl.testing import absltest from absl.testing import parameterized import numpy as np import jax from jax import lax from jax import numpy as jnp from jax import ops from jax._src import config from jax._src import dtypes from jax._src import test_util as jtu from jax._src import util from jax._src.lax import lax as lax_internal from jax._src.util import NumpyComplexWarning config.parse_flags_with_absl() # We disable the whitespace continuation check in this file because otherwise it # makes the test name formatting unwieldy. # pylint: disable=bad-continuation ARRAY_MSG = r"Using a non-tuple sequence for multidimensional indexing is not allowed.*arr\[array\(seq\)\]" TUPLE_MSG = r"Using a non-tuple sequence for multidimensional indexing is not allowed.*arr\[tuple\(seq\)\]" float_dtypes = jtu.dtypes.floating default_dtypes = float_dtypes + jtu.dtypes.integer all_dtypes = default_dtypes + jtu.dtypes.boolean class IndexSpec(typing.NamedTuple): shape: tuple[int, ...] indexer: Any out_shape: tuple[int, ...] | None = None def check_grads(f, args, order, atol=None, rtol=None, eps=None): # TODO(mattjj,dougalm): add higher-order check default_tol = 1e-6 if config.enable_x64.value else 1e-2 atol = atol or default_tol rtol = rtol or default_tol eps = eps or default_tol jtu.check_jvp(f, partial(jax.jvp, f), args, atol, rtol, eps) jtu.check_vjp(f, partial(jax.vjp, f), args, atol, rtol, eps) STATIC_INDEXING_TESTS = [ ("OneIntIndex", [ IndexSpec(shape=(3,), indexer=1, out_shape=()), IndexSpec(shape=(3, 3), indexer=0, out_shape=(3,)), IndexSpec(shape=(3, 4, 5), indexer=2, out_shape=(4, 5)), IndexSpec(shape=(3,), indexer=-1, out_shape=()), IndexSpec(shape=(3,), indexer=-2, out_shape=()), ]), ("TwoIntIndices", [ IndexSpec(shape=(3, 3), indexer=(2, 1), out_shape=()), IndexSpec(shape=(3, 4, 5), indexer=(1, 2), out_shape=(5,)), IndexSpec(shape=(3, 4, 5), indexer=(-1, 2), out_shape=(5,)), ]), ("ThreeIntIndices", [ IndexSpec(shape=(3, 4, 5), indexer=(1, 2, 3), out_shape=()), ]), ("OneSliceIndex", [ IndexSpec(shape=(10,), indexer=slice(1, 3), out_shape=(2,)), IndexSpec(shape=(10,), indexer=slice(1, -1), out_shape=(8,)), IndexSpec(shape=(10,), indexer=slice(None, -1), out_shape=(9,)), IndexSpec(shape=(10,), indexer=slice(None, None, None), out_shape=(10,)), IndexSpec(shape=(10, 8), indexer=slice(1, 3), out_shape=(2, 8)), IndexSpec(shape=(10, 8), indexer=slice(1, None), out_shape=(9, 8)), IndexSpec(shape=(10, 8), indexer=slice(None, 3), out_shape=(3, 8)), IndexSpec(shape=(10, 8), indexer=slice(-3, None), out_shape=(3, 8)), ]), ("OneSliceIndexNegativeStride", [ IndexSpec(shape=(10,), indexer=slice(3, 1, -1), out_shape=(2,)), IndexSpec(shape=(10,), indexer=slice(1, 8, -1), out_shape=(0,)), IndexSpec(shape=(10,), indexer=slice(None, 1, -2), out_shape=(4,)), IndexSpec(shape=(10,), indexer=slice(None, None, -1), out_shape=(10,)), IndexSpec(shape=(10, 8), indexer=slice(3, 1, -1), out_shape=(2, 8)), IndexSpec(shape=(10, 8), indexer=slice(0, 8, -1), out_shape=(0, 8)), IndexSpec(shape=(10, 8), indexer=slice(None, None, -1), out_shape=(10, 8)), ]), ("SliceIndexClamping", [ IndexSpec(shape=(10,), indexer=slice(2, 11, 1), out_shape=(8,)), IndexSpec(shape=(10,), indexer=slice(11, 12, 1), out_shape=(0,)), IndexSpec(shape=(10,), indexer=slice(-11, -2, 1), out_shape=(8,)), IndexSpec(shape=(10,), indexer=slice(-2, -12, -1), out_shape=(9,)), IndexSpec(shape=(10,), indexer=slice(12, -12, -1), out_shape=(10,)), ]), ("OneSliceIndexNonUnitStride", [ IndexSpec(shape=(10,), indexer=slice(0, 8, 2), out_shape=(4,)), IndexSpec(shape=(10,), indexer=slice(0, 8, 3), out_shape=(3,)), IndexSpec(shape=(10,), indexer=slice(1, 3, 2), out_shape=(1,)), IndexSpec(shape=(10,), indexer=slice(1, None, 2), out_shape=(5,)), IndexSpec(shape=(10,), indexer=slice(None, 1, -2), out_shape=(4,)), IndexSpec(shape=(10, 8), indexer=slice(1, 8, 3), out_shape=(3, 8)), IndexSpec(shape=(10, 8), indexer=slice(None, None, 2), out_shape=(5, 8)), IndexSpec(shape=(10, 8), indexer=slice(None, 1, -2), out_shape=(4, 8)), IndexSpec(shape=(10, 8), indexer=slice(None, None, -2), out_shape=(5, 8)), ]), ("TwoSliceIndices", [ IndexSpec(shape=(10, 8), indexer=(slice(1, 3), slice(0, 2)), out_shape=(2, 2)), IndexSpec(shape=(10, 8), indexer=(slice(1, None), slice(None, 2)), out_shape=(9, 2)), IndexSpec(shape=(10, 8), indexer=(slice(None, None, -1), slice(None, 2)), out_shape=(10, 2)), IndexSpec(shape=(10, 8, 3), indexer=(slice(1, 3), slice(0, 2)), out_shape=(2, 2, 3)), IndexSpec(shape=(10, 8, 3), indexer=(slice(1, 3), slice(0, None)), out_shape=(2, 8, 3)), IndexSpec(shape=(10, 8, 3), indexer=(slice(1, None), slice(0, 2)), out_shape=(9, 2, 3)), ]), ("OneColonIndex", [ IndexSpec(shape=(3,), indexer=slice(None), out_shape=(3,)), IndexSpec(shape=(3, 4), indexer=slice(None), out_shape=(3, 4)), ]), ("MultipleColonIndices", [ IndexSpec(shape=(3, 4), indexer=(slice(None), slice(None)), out_shape=(3, 4)), IndexSpec(shape=(3, 4, 5), indexer=(slice(None), slice(None)), out_shape=(3, 4, 5)), ]), ("MixedSliceIndices", [ IndexSpec(shape=(10, 4), indexer=(slice(None), slice(0, 2)), out_shape=(10, 2)), IndexSpec(shape=(10, 4), indexer=(1, slice(None)), out_shape=(4,)), ]), ("EllipsisIndex", [ IndexSpec(shape=(3,), indexer=Ellipsis, out_shape=(3,)), IndexSpec(shape=(3, 4), indexer=Ellipsis, out_shape=(3, 4)), IndexSpec(shape=(3, 4, 5), indexer=(0, Ellipsis), out_shape=(4, 5)), IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, 2, 3), out_shape=(3,)), ]), ("NoneIndex", [ IndexSpec(shape=(), indexer=None, out_shape=(1,)), IndexSpec(shape=(), indexer=(None, None), out_shape=(1, 1)), IndexSpec(shape=(), indexer=(Ellipsis, None), out_shape=(1,)), IndexSpec(shape=(3,), indexer=None, out_shape=(1, 3)), IndexSpec(shape=(3, 4), indexer=None, out_shape=(1, 3, 4)), IndexSpec(shape=(3, 4), indexer=(Ellipsis, None), out_shape=(3, 4, 1)), IndexSpec(shape=(3, 4), indexer=(0, None, Ellipsis), out_shape=(1, 4)), IndexSpec(shape=(3, 4, 5), indexer=(1, None, Ellipsis), out_shape=(1, 4, 5)), ]), ("EmptyIndex", [ IndexSpec(shape=(), indexer=(), out_shape=()), IndexSpec(shape=(3,), indexer=(), out_shape=(3,)), IndexSpec(shape=(3, 4), indexer=(), out_shape=(3, 4)), ]), ("TupleOfIntAndSliceAndIntArray", [ IndexSpec(shape=(3, 2, 3), indexer=(0, slice(None), np.arange(3)), out_shape=(3, 2)), IndexSpec(shape=(3, 2, 3), indexer=(np.int32(1), slice(None), np.arange(3)), out_shape=(3, 2)), IndexSpec(shape=(3, 2, 3), indexer=(np.array(2), slice(None), np.arange(3)), out_shape=(3, 2)), ]), ] STATIC_INDEXING_OUT_OF_BOUNDS_TESTS = [ ("OneIntIndex", [ IndexSpec(shape=(3,), indexer=-4, out_shape=()), IndexSpec(shape=(3, 3), indexer=3, out_shape=(3,)), IndexSpec(shape=(3, 4, 5), indexer=4, out_shape=(4, 5)), ]), ("TwoIntIndices", [ IndexSpec(shape=(3, 3), indexer=(2, -4), out_shape=()), IndexSpec(shape=(3, 4, 5), indexer=(3, 2), out_shape=()), IndexSpec(shape=(3, 4, 5), indexer=(-4, 4), out_shape=(5,)), ]), ] ADVANCED_INDEXING_TESTS = [ ("One1DIntArrayIndex", [ IndexSpec(shape=(3,), indexer=np.array([0, 1]), out_shape=(2,)), IndexSpec(shape=(3, 3), indexer=np.array([1, 2, 1]), out_shape=(3, 3)), IndexSpec(shape=(3, 4, 5), indexer=np.array([0, 2, 0, 1]), out_shape=(4, 4, 5)), IndexSpec(shape=(3,), indexer=np.array([-1, 1]), out_shape=(2,)), IndexSpec(shape=(3,), indexer=np.array([-2, -1]), out_shape=(2,)), IndexSpec(shape=(0,), indexer=np.array([], dtype=np.int32), out_shape=(0,)), ]), ("One2DIntArrayIndex", [ IndexSpec(shape=(3,), indexer=np.array([[0, 0]]),out_shape=(1, 2)), IndexSpec(shape=(3, 3), indexer=np.array([[1, 2, 1], [0, 1, -1]]), out_shape=(2, 3, 3)), IndexSpec(shape=(3, 4, 5), indexer=np.array([[0, 2, 0, 1], [-1, -2, 1, 0]]), out_shape=(2, 4, 4, 5)), ]), ("Two1DIntArrayIndicesNoBroadcasting", [ IndexSpec(shape=(3, 3), indexer=(np.array([0, 1]), np.array([1, 2])), out_shape=(2,)), IndexSpec(shape=(3, 4, 5), indexer=(np.array([0, 2, 0, 1]), np.array([-1, 0, -1, 2])), out_shape=(4, 5)), ]), ("Two1DIntArrayIndicesWithBroadcasting", [ IndexSpec(shape=(3, 3), indexer=(np.array([[0, 1]]), np.array([1, 2])), out_shape=(1, 2)), IndexSpec(shape=(3, 4, 5), indexer=(np.array([[0, 2, 0, 1]]), np.array([-1, 0, -1, 2])), out_shape=(1, 4, 5)), ]), ("ArrayOfInts", [ IndexSpec(shape=(3,), indexer=np.array([0, 1, 0]), out_shape=(3,)), IndexSpec(shape=(3, 4, 5), indexer=np.array([ 0, -1]), out_shape=(2, 4, 5)), ]), ("TupleOfListsOfPythonInts", [ IndexSpec(shape=(3, 4, 5), indexer=([0, 1],), out_shape=(2, 4, 5)), IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], [[2, 3, 0, 3]]), out_shape=(2, 4, 5)), ]), ("TupleOfPythonIntsAndIntArrays", [ IndexSpec(shape=(3, 4, 5), indexer=(0, np.array([0, 1])), out_shape=(2, 5)), IndexSpec(shape=(3, 4, 5), indexer=(0, 1, np.array([[2, 3, 0, 3]])), out_shape=(1, 4)), ]), ("TupleOfListsOfPythonIntsAndIntArrays", [ IndexSpec(shape=(3, 4, 5), indexer=([0, 1], np.array([0])), out_shape=(2, 5)), IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], np.array([[2, 3, 0, 3]])), out_shape=(2, 4, 5)), ]), ] ADVANCED_INDEXING_TESTS_NO_REPEATS = [ ("One1DIntArrayIndex", [ IndexSpec(shape=(3,), indexer=np.array([0, 1]), out_shape=(2,)), IndexSpec(shape=(3, 3), indexer=np.array([1, 2, 0]), out_shape=(3, 3)), IndexSpec(shape=(3, 4, 5), indexer=np.array([0, 2, 1]), out_shape=(3, 4, 5)), IndexSpec(shape=(3,), indexer=np.array([-1, 1]), out_shape=(2,)), IndexSpec(shape=(3,), indexer=np.array([-2, -1]), out_shape=(2,)), IndexSpec(shape=(0,), indexer=np.array([], dtype=np.int32), out_shape=(0,)), ]), ("One2DIntArrayIndex", [ IndexSpec(shape=(3,), indexer=np.array([[0, 1]]), out_shape=(1, 2)), IndexSpec(shape=(6, 6), indexer=np.array([[1, 2, 0], [3, 4, -1]]), out_shape=(2, 3, 6)), ]), ("Two1DIntArrayIndicesNoBroadcasting", [ IndexSpec(shape=(3, 3), indexer=(np.array([0, 1]), np.array([1, 2])), out_shape=(2,)), IndexSpec(shape=(4, 5, 6), indexer=(np.array([0, 2, 1, 3]), np.array([-1, 0, -2, 1])), out_shape=(4, 6)), ]), ("Two1DIntArrayIndicesWithBroadcasting", [ IndexSpec(shape=(3, 3), indexer=(np.array([[0, 1]]), np.array([1, 2])), out_shape=(1, 2)), IndexSpec(shape=(4, 5, 6), indexer=(np.array([[0, 2, -1, 1]]), np.array([-1, 0, -2, 2])), out_shape=(1, 4, 6)), ]), ("ArrayOfInts", [ IndexSpec(shape=(3,), indexer=np.array([0, 2, 1]), out_shape=(3,)), IndexSpec(shape=(3, 4, 5), indexer=np.array([ 0, -1]), out_shape=(2, 4, 5)), ]), ("TupleOfListsOfPythonInts", [ IndexSpec(shape=(3, 4, 5), indexer=([0, 1],), out_shape=(2, 4, 5)), IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], [[2, 3, 0]]), out_shape=(2, 3, 5)), ]), ("TupleOfPythonIntsAndIntArrays", [ IndexSpec(shape=(3, 4, 5), indexer=(0, np.array([0, 1])), out_shape=(2, 5)), IndexSpec(shape=(3, 4, 5), indexer=(0, 1, np.array([[2, 3, 0]])), out_shape=(1, 3)), ]), ("TupleOfListsOfPythonIntsAndIntArrays", [ IndexSpec(shape=(3, 4, 5), indexer=([0, 1], np.array([0])), out_shape=(2, 5)), IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], np.array([[2, 3, 0]])), out_shape=(2, 3, 5)), ]), ] ADVANCED_INDEXING_TESTS_NO_REPEATS_SORTED = [ ("One1DIntArrayIndex", [ IndexSpec(shape=(3,), indexer=np.array([0, 1]), out_shape=(2,)), IndexSpec(shape=(3, 3), indexer=np.array([0, 1, 2]), out_shape=(3, 3)), IndexSpec(shape=(3, 4, 5), indexer=np.array([0, 1, 2]), out_shape=(3, 4, 5)), IndexSpec(shape=(3,), indexer=np.array([-1, 1]), out_shape=(2,)), IndexSpec(shape=(3,), indexer=np.array([-2, -1]), out_shape=(2,)), IndexSpec(shape=(0,), indexer=np.array([], dtype=np.int32), out_shape=(0,)), ]), ("One2DIntArrayIndex", [ IndexSpec(shape=(3,), indexer=np.array([[0, 1]]), out_shape=(1, 2)), IndexSpec(shape=(6, 6), indexer=np.array([[-1, 0, 1], [ 2, 3, 4]]), out_shape=(2, 3, 6)), ]), ("Two1DIntArrayIndicesNoBroadcasting", [ IndexSpec(shape=(3, 3), indexer=(np.array([0, 1]), np.array([1, 2])), out_shape=(2,)), IndexSpec(shape=(4, 5, 6), indexer=(np.array([0, 1, 2, 3]), np.array([-2, -1, 0, 1])), out_shape=(4, 6)), ]), ("Two1DIntArrayIndicesWithBroadcasting", [ IndexSpec(shape=(3, 3), indexer=(np.array([[0, 1]]), np.array([1, 2])), out_shape=(1, 2)), IndexSpec(shape=(4, 5, 6), indexer=(np.array([[-1, 0, 1, 2]]), np.array([-2, -1, 0, 2])), out_shape=(1, 4, 6)), ]), ("TupleOfListsOfPythonInts", [ IndexSpec(shape=(3, 4, 5), indexer=([0, 1],), out_shape=(2, 4, 5)), IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], [[0, 2, 3]]), out_shape=(2, 3, 5)), ]), ("TupleOfPythonIntsAndIntArrays", [ IndexSpec(shape=(3, 4, 5), indexer=(0, np.array([0, 1])), out_shape=(2, 5)), IndexSpec(shape=(3, 4, 5), indexer=(0, 1, np.array([[0, 2, 3]])), out_shape=(1, 3)), ]), ("TupleOfListsOfPythonIntsAndIntArrays", [ IndexSpec(shape=(3, 4, 5), indexer=([0, 1], np.array([0])), out_shape=(2, 5)), IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], np.array([[0, 2, 3]])), out_shape=(2, 3, 5)), ]), ] MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS = [ ("SlicesAndOneIntArrayIndex", [ IndexSpec(shape=(2, 3), indexer=(np.array([0, 1]), slice(1, 2)), out_shape=(2, 1)), IndexSpec(shape=(2, 3), indexer=(slice(0, 2), np.array([0, 2])), out_shape=(2, 2)), IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, np.array([0, 2]), slice(None)), out_shape=(3, 2, 5)), IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, np.array([[0, 2], [1, 3]]), slice(None)), out_shape=(3, 2, 2, 5)), ]), ("SlicesAndTwoIntArrayIndices", [ IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, np.array([0, 2]), np.array([-1, 2])), out_shape=(3, 2)), IndexSpec(shape=(3, 4, 5), indexer=(np.array([0, 2]), Ellipsis, np.array([-1, 2])), out_shape=(2, 4)), IndexSpec(shape=(3, 4, 5), indexer=(np.array([0, 2]), np.array([-1, 2]), Ellipsis), out_shape=(2, 5)), IndexSpec(shape=(3, 4, 5), indexer=(np.array([0, 2]), np.array([-1, 2]), slice(1, 3)), out_shape=(2, 2)), IndexSpec(shape=(3, 4, 5), indexer=(np.array([0, 2]), slice(1, 3), np.array([-1, 2])), out_shape=(2, 2)), IndexSpec(shape=(3, 4, 5), indexer=(np.array([ 0, 2, -2]), slice(None, None, 2), np.array([-1, 2, 1])), out_shape=(3, 2)), ]), ("NonesAndIntArrayIndices", [ IndexSpec(shape=(3, 4, 5), indexer=(np.array([0, 2]), None, np.array([-1, 2])), out_shape=(2, 1, 5)), IndexSpec(shape=(3, 4, 5), indexer=(np.array([0, 2]), None, None, np.array([-1, 2])), out_shape=(2, 1, 1, 5)), IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, np.array([0, 2]), None, None, np.array([-1, 2])), out_shape=(2, 3, 1, 1)), ]), ("IntArrayWithInt32Type", [ IndexSpec(shape=(3, 4), indexer=(Ellipsis, np.array(1, dtype=np.int32)), out_shape=(3,)), ]), ("EllipsisWithArrayIndices", [ IndexSpec(shape=(3, 4, 5), indexer=(np.array([0, 1]), ..., np.array([0, 1])), out_shape=(2, 4)), IndexSpec(shape=(3, 4, 5), indexer=(slice(None), np.array([0, 1]), ..., np.array([0, 1])), out_shape=(2, 3)), IndexSpec(shape=(3, 4, 5), indexer=(slice(None), ..., np.array([0, 1]), np.array([0, 1])), out_shape=(3, 2)), ]), ] MIXED_ADVANCED_INDEXING_TESTS = MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS + [ ("SlicesAndOneIntArrayIndex", [ IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, np.array([[0, 2], [1, 1]]), slice(None)), out_shape=(3, 2, 2, 5)), ]), ("SlicesAndTwoIntArrayIndices", [ IndexSpec(shape=(3, 4, 5), indexer=(np.array([ 0, 2, -2]), slice(None, None, 2), np.array([-1, 2, -1])), out_shape=(3, 2)), IndexSpec(shape=(3, 4, 5), indexer=(np.array([[0, 2], [2, 0]]), Ellipsis, np.array([[1, 0], [1, 0]])), out_shape=(2, 2, 4)), ]), ] MODES = ["clip", "drop", "promise_in_bounds"] class IndexingTest(jtu.JaxTestCase): """Tests for Numpy indexing translation rules.""" @jtu.sample_product( [dict(name=name, shape=shape, indexer=indexer) for name, index_specs in STATIC_INDEXING_TESTS for shape, indexer, _ in index_specs], dtype=all_dtypes ) def testStaticIndexing(self, name, shape, dtype, indexer): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] np_fun = lambda x: np.asarray(x)[indexer] jnp_fun = lambda x: jnp.asarray(x)[indexer] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) # Tests x.at[...].get(...) as well. jnp_fun = lambda x: jnp.asarray(x).at[indexer].get() self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) def testStaticIndexingWithJaxArray(self): shape = (10,) indexer = slice(jnp.array(2, dtype=np.int32), np.array(11, dtype=np.int32), jnp.array(1, dtype=np.int32)) rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, np.int32)] np_fun = lambda x: np.asarray(x)[indexer] jnp_fun = lambda x: jnp.asarray(x)[indexer] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) # Tests x.at[...].get(...) as well. jnp_fun = lambda x: jnp.asarray(x).at[indexer].get() self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( funcname=["negative", "sin", "cos", "square", "sqrt", "log", "exp"], ) def testIndexApply(self, funcname, size=10, dtype='float32'): rng = jtu.rand_default(self.rng()) idx_rng = jtu.rand_int(self.rng(), -size, size) np_func = getattr(np, funcname) jnp_func = getattr(jnp, funcname) @jtu.ignore_warning(category=RuntimeWarning) def np_op(x, idx): y = x.copy() np_func.at(y, idx) return y def jnp_op(x, idx): return jnp.asarray(x).at[idx].apply(jnp_func) # Test with traced integer index args_maker = lambda: [rng(size, dtype), idx_rng(size, int)] tol = ( 5e-5 if jtu.test_device_matches(["tpu"]) and funcname in ("log", "exp") else None ) self._CheckAgainstNumpy(np_op, jnp_op, args_maker, atol=tol) self._CompileAndCheck(jnp_op, args_maker) # Test with slice index idx = slice(1, 5) np_op_idx = partial(np_op, idx=idx) jnp_op_idx = partial(jnp_op, idx=idx) args_maker = lambda: [rng(size, dtype)] self._CheckAgainstNumpy(np_op_idx, jnp_op_idx, args_maker, atol=tol, rtol=tol) self._CompileAndCheck(jnp_op_idx, args_maker) def testIndexApplyBatchingBug(self): # https://github.com/jax-ml/jax/issues/16655 arr = jnp.array([[1, 2, 3, 4, 5, 6]]) ind = jnp.array([3]) func = lambda a, i: a.at[i].apply(lambda x: x - 1) expected = jnp.array(list(map(func, arr, ind))) out = jax.vmap(func)(arr, ind) self.assertArraysEqual(out, expected) def testIndexUpdateScalarBug(self): # https://github.com/jax-ml/jax/issues/14923 a = jnp.arange(10.) out = a.at[0].apply(jnp.cos) self.assertArraysEqual(out, a.at[0].set(1)) @jtu.sample_product( [dict(name=name, shape=shape, indexer=indexer, mode=mode) for mode in MODES for name, index_specs in ( STATIC_INDEXING_TESTS if mode == "promise_in_bounds" else STATIC_INDEXING_TESTS + STATIC_INDEXING_OUT_OF_BOUNDS_TESTS) for shape, indexer, _ in index_specs ], dtype=float_dtypes, ) def testStaticIndexingGrads(self, name, shape, dtype, indexer, mode): rng = jtu.rand_default(self.rng()) tol = 1e-2 if jnp.finfo(dtype).bits == 32 else None arg = rng(shape, dtype) # Use an arbitrary finite fill_value, since NaNs won't work in a numerical # gradient test. fun = lambda x: jnp.asarray(x).at[indexer].get(mode=mode, fill_value=7)**2 check_grads(fun, (arg,), 2, tol, tol, tol) def _ReplaceSlicesWithTuples(self, idx): """Helper method to replace slices with tuples for dynamic indexing args.""" if isinstance(idx, slice): triple = idx.start, idx.stop, idx.step isnone = [i for i, elt in enumerate(triple) if elt is None] zeros = itertools.repeat(0) nones = itertools.repeat(None) out = util.subvals(triple, zip(isnone, zeros)) return out, lambda out: slice(*util.subvals(out, zip(isnone, nones))) elif isinstance(idx, (tuple, list)) and idx: t = type(idx) elts, packs = zip(*map(self._ReplaceSlicesWithTuples, idx)) return elts, lambda elts: t((pack(i) for pack, i in zip(packs, elts))) else: return idx, lambda x: x @jtu.sample_product( [dict(name=name, shape=shape, indexer=indexer) for name, index_specs in [ ("OneSliceIndex", [IndexSpec(shape=(5,), indexer=slice(1, 3)), IndexSpec(shape=(5, 4), indexer=slice(1, 3))]), ("TwoSliceIndices", [IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))), IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2)))]), ("NonUnitStrides", [ IndexSpec(shape=(3,), indexer=slice(None, None, -1)), IndexSpec(shape=(3, 3), indexer=slice(0, 3, -2)), IndexSpec(shape=(3, 4, 5), indexer=slice(0, 4, 2)) ]), ("OnlyStartOrStopDynamic", [ IndexSpec(shape=(5, 4), indexer=(slice(None, 3), slice(0, 2))), IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None))) ]), ] for shape, indexer, _ in index_specs ], dtype=all_dtypes, ) def testDynamicIndexingWithSlicesErrors(self, name, shape, dtype, indexer): rng = jtu.rand_default(self.rng()) unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) @jax.jit def fun(x, unpacked_indexer): indexer = pack_indexer(unpacked_indexer) return x[indexer] args_maker = lambda: [rng(shape, dtype), unpacked_indexer] self.assertRaises(IndexError, lambda: fun(*args_maker())) @jtu.sample_product( [dict(name=name, shape=shape, indexer=indexer) for name, index_specs in [ ("OneIntIndex", [IndexSpec(shape=(3,), indexer=1), IndexSpec(shape=(3, 3), indexer=0), IndexSpec(shape=(3, 4, 5), indexer=2), IndexSpec(shape=(3,), indexer=-1), IndexSpec(shape=(3,), indexer=-2)]), ("TwoIntIndices", [IndexSpec(shape=(3, 3), indexer=(2, 1)), IndexSpec(shape=(3, 4, 5), indexer=(1, 2)), IndexSpec(shape=(3, 4, 5), indexer=(-1, 2))]), ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), ] for shape, indexer, _ in index_specs ], dtype=all_dtypes, ) def testDynamicIndexingWithIntegers(self, name, shape, dtype, indexer): rng = jtu.rand_default(self.rng()) unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) def np_fun(x, unpacked_indexer): indexer = pack_indexer(unpacked_indexer) return np.asarray(x)[indexer] def jnp_fun(x, unpacked_indexer): indexer = pack_indexer(unpacked_indexer) return jnp.array(x)[indexer] args_maker = lambda: [rng(shape, dtype), unpacked_indexer] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( [dict(name=name, shape=shape, indexer=indexer) for name, index_specs in [ ("OneIntIndex", [IndexSpec(shape=(3,), indexer=1), IndexSpec(shape=(3, 3), indexer=0), IndexSpec(shape=(3, 4, 5), indexer=2), IndexSpec(shape=(3,), indexer=-1), IndexSpec(shape=(3,), indexer=-2), ]), ("TwoIntIndices", [IndexSpec(shape=(3, 3), indexer=(2, 1)), IndexSpec(shape=(3, 4, 5), indexer=(1, 2)), IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)), ]), ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), ] for shape, indexer, _ in index_specs ], dtype=float_dtypes, ) def testDynamicIndexingWithIntegersGrads(self, name, shape, dtype, indexer): rng = jtu.rand_default(self.rng()) tol = 1e-2 if jnp.finfo(dtype).bits == 32 else None unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) @jax.jit def fun(unpacked_indexer, x): indexer = pack_indexer(unpacked_indexer) return x[indexer] arr = rng(shape, dtype) check_grads(partial(fun, unpacked_indexer), (arr,), 2, tol, tol, tol) @jtu.sample_product( [dict(name=name, shape=shape, indexer=indexer) for name, index_specs in ADVANCED_INDEXING_TESTS for shape, indexer, _ in index_specs ], dtype=all_dtypes, ) def testAdvancedIntegerIndexing(self, name, shape, dtype, indexer): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype), indexer] np_fun = lambda x, idx: np.asarray(x)[idx] jnp_fun = lambda x, idx: jnp.asarray(x)[idx] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product(dtype=jtu.dtypes.unsigned + jtu.dtypes.integer) def testIndicesNormalizationByType(self, dtype): x = jnp.arange(10) jaxpr = jax.make_jaxpr(x.__getitem__)(jnp.arange(3, dtype=dtype)) primitives = [eqn.primitive for eqn in jaxpr.eqns] if np.issubdtype(dtype, np.unsignedinteger): # Unsigned integers should not require lt, add, and select. self.assertEqual(primitives, [lax.convert_element_type_p, lax.broadcast_in_dim_p, lax.gather_p]) else: # May or may not contain convert_element_type. self.assertIn(len(primitives), [5, 6]) self.assertEqual(primitives[:3], [lax.lt_p, lax.add_p, lax.select_n_p]) self.assertEqual(primitives[-2:], [lax.broadcast_in_dim_p, lax.gather_p]) @jtu.sample_product( [dict(name=name, shape=shape, indexer=indexer) for name, index_specs in [ ("One1DIntArrayIndex", [IndexSpec(shape=(3,), indexer=np.array([0, 1])), IndexSpec(shape=(3, 3), indexer=np.array([1, 2, 1])), IndexSpec(shape=(3, 4, 5), indexer=np.array([0, 2, 0, 1])), IndexSpec(shape=(3,), indexer=np.array([-1, 1])), IndexSpec(shape=(3,), indexer=np.array([-2, -1])), ]), ("One2DIntArrayIndex", [IndexSpec(shape=(3,), indexer=np.array([[0, 0]])), IndexSpec(shape=(3, 3), indexer=np.array([[1, 2, 1], [0, 1, -1]])), IndexSpec(shape=(3, 4, 5), indexer=np.array([[0, 2, 0, 1], [-1, -2, 1, 0]])), ]), ("Two1DIntArrayIndicesNoBroadcasting", [IndexSpec(shape=(3, 3), indexer=(np.array([0, 1]), np.array([1, 2]))), IndexSpec(shape=(3, 4, 5), indexer=(np.array([0, 2, 0, 1]), np.array([-1, 0, -1, 2]))), ]), ("Two1DIntArrayIndicesWithBroadcasting", [IndexSpec(shape=(3, 3), indexer=(np.array([[0, 1]]), np.array([1, 2]))), IndexSpec(shape=(3, 4, 5), indexer=(np.array([[0, 2, 0, 1]]), np.array([-1, 0, -1, 2]))), ]), ("TupleOfPythonIntsAndIntArrays", [IndexSpec(shape=(3, 4, 5), indexer=(0, np.array([0, 1]))), IndexSpec(shape=(3, 4, 5), indexer=(0, 1, np.array([[2, 3, 0, 3]]))), ]), ("TupleOfListsOfPythonIntsAndIntArrays", [IndexSpec(shape=(3, 4, 5), indexer=([0, 1], np.array([0]))), IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], np.array([[2, 3, 0, 3]]))), ]), ] for shape, indexer, _ in index_specs ], dtype=float_dtypes, ) def testAdvancedIntegerIndexingGrads(self, name, shape, dtype, indexer): rng = jtu.rand_default(self.rng()) tol = 1e-2 if jnp.finfo(dtype).bits == 32 else None arg = rng(shape, dtype) fun = lambda x: jnp.asarray(x)[indexer] check_grads(fun, (arg,), 2, tol, tol, eps=1.) @jtu.sample_product( [dict(name=name, shape=shape, indexer=indexer) for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS for shape, indexer, _ in index_specs ], dtype=all_dtypes, ) def testMixedAdvancedIntegerIndexing(self, name, shape, dtype, indexer): rng = jtu.rand_default(self.rng()) indexer_with_dummies = [e if isinstance(e, np.ndarray) else () for e in indexer] substitutes = [(i, e) for i, e in enumerate(indexer) if not isinstance(e, np.ndarray)] args_maker = lambda: [rng(shape, dtype), indexer_with_dummies] def jnp_fun(x, indexer_with_dummies): idx = type(indexer)(util.subvals(indexer_with_dummies, substitutes)) return jnp.asarray(x)[idx] def np_fun(x, indexer_with_dummies): idx = type(indexer)(util.subvals(indexer_with_dummies, substitutes)) return np.asarray(x)[idx] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) def testAdvancedIndexingManually(self): x = self.rng().randn(3, 4, 5) index_array = np.array([0, 2, -1, 0]) op = lambda x, index_array: x[..., index_array, :] cop = jax.jit(op) a1 = op(x, index_array) a2 = cop(x, index_array) self.assertAllClose(a1, a2) op = lambda x, index_array: x[..., index_array, :, index_array, None] cop = jax.jit(op) a1 = op(x, index_array) a2 = cop(x, index_array) self.assertAllClose(a1, a2) op = lambda x, index_array: x[index_array, ..., index_array[:, None], None] cop = jax.jit(op) a1 = op(x, index_array) a2 = cop(x, index_array) self.assertAllClose(a1, a2) def testUnpacking(self): def foo(x): a, b, c = x return a + b + c cfoo = jax.jit(foo) a1 = foo(np.arange(3)) a2 = cfoo(np.arange(3)) self.assertAllClose(a1, a2) def testBooleanIndexingArray1D(self): idx = np.array([True, True, False]) x = jax.device_put(np.arange(3)) ans = x[idx] expected = np.arange(3)[idx] self.assertAllClose(ans, expected, check_dtypes=False) def testBooleanIndexingList1D(self): idx = [True, True, False] x = jax.device_put(np.arange(3)) with self.assertRaisesRegex(TypeError, ARRAY_MSG): x[idx] def testBooleanIndexingArray2DBroadcast(self): idx = np.array([True, True, False, True]) x = np.arange(8).reshape(4, 2) ans = jax.device_put(x)[idx] expected = x[idx] self.assertAllClose(ans, expected, check_dtypes=False) def testBooleanIndexingList2DBroadcast(self): idx = [True, True, False, True] x = np.arange(8).reshape(4, 2) with self.assertRaisesRegex(TypeError, ARRAY_MSG): jax.device_put(x)[idx] def testBooleanIndexingArray2D(self): idx = np.array([[True, False], [False, True], [False, False], [True, True]]) x = np.arange(8).reshape(4, 2) ans = jax.device_put(x)[idx] expected = x[idx] self.assertAllClose(ans, expected, check_dtypes=False) def testBoolean1DIndexingWithEllipsis(self): # Regression test for https://github.com/jax-ml/jax/issues/8412 x = np.arange(24).reshape(4, 3, 2) idx = (..., np.array([True, False])) ans = jnp.array(x)[idx] expected = x[idx] self.assertAllClose(ans, expected, check_dtypes=False) def testBoolean1DIndexingWithEllipsis2(self): # Regression test for https://github.com/jax-ml/jax/issues/9050 x = np.arange(3) idx = (..., np.array([True, False, True])) ans = jnp.array(x)[idx] expected = x[idx] self.assertAllClose(ans, expected, check_dtypes=False) def testBoolean1DIndexingWithEllipsis3(self): x = np.arange(6).reshape(2, 3) idx = (0, ..., np.array([True, False, True])) ans = jnp.array(x)[idx] expected = x[idx] self.assertAllClose(ans, expected, check_dtypes=False) def testBoolean2DIndexingWithEllipsis(self): x = np.arange(24).reshape(4, 3, 2) idx = (..., np.array([[True, False], [True, False], [False, False]])) ans = jnp.array(x)[idx] expected = x[idx] self.assertAllClose(ans, expected, check_dtypes=False) def testBoolean1DIndexingWithTrailingEllipsis(self): x = np.arange(24).reshape(4, 3, 2) idx = (np.array([True, False, True, False]), ...) ans = jnp.array(x)[idx] expected = x[idx] self.assertAllClose(ans, expected, check_dtypes=False) def testBooleanIndexingDynamicShapeError(self): x = np.zeros(3) i = np.array([True, True, False]) self.assertRaises(IndexError, lambda: jax.jit(lambda x, i: x[i])(x, i)) def testIssue187(self): x = jnp.ones((5, 5)) x[[0, 2, 4], [0, 2, 4]] # doesn't crash x = np.arange(25).reshape((5, 5)) ans = jax.jit(lambda x: x[[0, 2, 4], [0, 2, 4]])(x) expected = x[[0, 2, 4], [0, 2, 4]] self.assertAllClose(ans, expected, check_dtypes=False) def testJVPOfGradOfIndexing(self): # Should return a value, even though we didn't pass a symbolic zero as the # index tangent. x = jnp.ones((3, 4), jnp.float32) i = jnp.ones((3,), jnp.int32) f = lambda x, i: jnp.sum(x[i]) primals, tangents = jax.jvp(jax.grad(f), (x, i), (x, np.zeros(i.shape, dtypes.float0))) expected = np.broadcast_to( np.array([0, 3, 0], dtype=np.float32)[:, None], (3, 4)) self.assertAllClose(expected, primals) self.assertAllClose(np.zeros_like(x), tangents) def testSimpleIndexingUsesSlice(self): jaxpr = jax.make_jaxpr(lambda x: x[:2, :2])(jnp.ones((3, 4))) self.assertEqual(len(jaxpr.jaxpr.eqns), 1) self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.slice_p) jaxpr = jax.make_jaxpr(lambda x: x[0, :2, 1])(jnp.ones((3, 4, 5))) self.assertEqual(len(jaxpr.jaxpr.eqns), 2) self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p) self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p) jaxpr = jax.make_jaxpr(lambda x: x[0, 0])(jnp.ones((3, 4, 5))) self.assertEqual(len(jaxpr.jaxpr.eqns), 2) self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p) self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p) jaxpr = jax.make_jaxpr(lambda x: x[:, 1])(jnp.ones((3, 4, 5))) self.assertEqual(len(jaxpr.jaxpr.eqns), 2) self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p) self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p) # Indexing with `Ellipsis` is not lowered to `gather`. jaxpr = jax.make_jaxpr(lambda x: x[..., 0])(jnp.ones((3, 4, 5))) self.assertLen((jaxpr.jaxpr.eqns), 2) self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p) self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p) # Simple reverses lower to lax.rev_p jaxpr = jax.make_jaxpr(lambda x: x[:, ::-1])(jnp.ones((3, 4))) self.assertEqual(len(jaxpr.jaxpr.eqns), 1) self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.rev_p) # Non-static indices produce a dynamic slice jaxpr = jax.make_jaxpr(lambda x, i: x[i])(jnp.ones((4,)), 2) self.assertEqual(len(jaxpr.jaxpr.eqns), 6) self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p) self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p) def testTrivialGatherIsntGenerated(self): # https://github.com/jax-ml/jax/issues/1621 jaxpr = jax.make_jaxpr(lambda x: x[:, None])(np.arange(4)) self.assertEqual(len(jaxpr.jaxpr.eqns), 1) self.assertNotIn('gather', str(jaxpr)) jaxpr = jax.make_jaxpr(lambda x: x[0:6:1])(np.arange(4)) self.assertEqual(len(jaxpr.jaxpr.eqns), 0) jaxpr = jax.make_jaxpr(lambda x: x[:4])(np.arange(4)) self.assertEqual(len(jaxpr.jaxpr.eqns), 0) jaxpr = jax.make_jaxpr(lambda x: x[::-1])(np.arange(4)) self.assertEqual(len(jaxpr.jaxpr.eqns), 1) self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.rev_p) def testOOBEmptySlice(self): x = jnp.arange(4, dtype='float32') self.assertArraysEqual(x[1:0], jnp.empty(0, dtype='float32')) self.assertArraysEqual(x[-2:-10], jnp.empty(0, dtype='float32')) self.assertArraysEqual(x[5:10], jnp.empty(0, dtype='float32')) x = jnp.arange(6, dtype='float32').reshape(2, 3) self.assertArraysEqual(x[-1:-4], jnp.empty((0, 3), dtype='float32')) self.assertArraysEqual(x[:, 3:2], jnp.empty((2, 0), dtype='float32')) def testIndexingEmptyDimension(self): # Issue 2671: XLA error when indexing into dimension of size 0 x = jnp.ones((2, 0)) # The following work, even on axis 1 of size 0 with jax.numpy_rank_promotion('allow'): _ = x[0, :] + x[0, None] + x[0, 1:] + x[0, 1:3:2] with self.assertRaisesRegex(IndexError, "index .* is out of bounds for axis .* with size 0"): _ = np.ones((2, 0))[0, 0] # The numpy error with self.assertRaisesRegex(IndexError, "index is out of bounds for axis .* with size 0"): _ = x[0, 0] # JAX indexing with self.assertRaisesRegex(IndexError, "index is out of bounds for axis .* with size 0"): jax.jit(lambda i: x[0, i])(0) # JAX indexing under jit def testBooleanIndexingWithEmptyResult(self): # based on a TensorFlow Probability test that started failing after #1622 x = jnp.array([-1]) mask = jnp.array([False]) ans = x[mask] # doesn't crash expected = np.array([-1])[np.array([False])] self.assertAllClose(ans, expected, check_dtypes=False) def testBooleanIndexingShapeMismatch(self): # Regression test for https://github.com/jax-ml/jax/issues/7329 x = jnp.arange(4) idx = jnp.array([True, False]) with self.assertRaisesRegex(IndexError, "boolean index did not match shape.*"): x[idx] def testBooleanIndexingWithNone(self): # Regression test for https://github.com/jax-ml/jax/issues/18542 x = jnp.arange(6).reshape(2, 3) idx = (None, jnp.array([True, False])) ans = x[idx] expected = jnp.arange(3).reshape(1, 1, 3) self.assertAllClose(ans, expected) def testBooleanIndexingWithNoneAndEllipsis(self): # Regression test for https://github.com/jax-ml/jax/issues/18542 x = jnp.arange(6).reshape(2, 3) mask = jnp.array([True, False, False]) ans = x[None, ..., mask] expected = jnp.array([0, 3]).reshape(1, 2, 1) self.assertAllClose(ans, expected) def testBooleanIndexingWithEllipsisAndNone(self): # Regression test for https://github.com/jax-ml/jax/issues/18542 x = jnp.arange(6).reshape(2, 3) mask = jnp.array([True, False, False]) ans = x[..., None, mask] expected = jnp.array([0, 3]).reshape(2, 1, 1) self.assertAllClose(ans, expected) def testNontrivialBooleanIndexing(self): # Test nontrivial corner case in boolean indexing shape validation rng = jtu.rand_default(self.rng()) index = (rng((2, 3), np.bool_), rng((6,), np.bool_)) args_maker = lambda: [rng((2, 3, 6), np.int32)] np_fun = lambda x: np.asarray(x)[index] jnp_fun = lambda x: jnp.asarray(x)[index] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) @parameterized.parameters( [(3,), (0,)], [(3, 4), (0,)], [(3, 4), (0, 4)], [(3, 4), (3, 0)], [(3, 4, 5), (3, 0)], ) def testEmptyBooleanIndexing(self, x_shape, m_shape): # Regression test for https://github.com/jax-ml/jax/issues/22886 rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(x_shape, np.int32), np.empty(m_shape, dtype=bool)] np_fun = lambda x, m: np.asarray(x)[np.asarray(m)] jnp_fun = lambda x, m: jnp.asarray(x)[jnp.asarray(m)] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) @jtu.sample_product( shape=[(2, 3, 4, 5)], idx=[ np.index_exp[True], np.index_exp[False], np.index_exp[..., True], np.index_exp[..., False], np.index_exp[0, :2, True], np.index_exp[0, :2, False], np.index_exp[:2, 0, True], np.index_exp[:2, 0, False], np.index_exp[:2, np.array([0, 2]), True], np.index_exp[np.array([1, 0]), :, True], np.index_exp[True, :, True, :, np.array(True)], ] ) def testScalarBooleanIndexing(self, shape, idx): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, np.int32)] np_fun = lambda x: np.asarray(x)[idx] jnp_fun = lambda x: jnp.asarray(x)[idx] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) @jtu.sample_product( shape=[(2, 3, 4, 5)], update_ndim=[0, 1, 2], idx=[ np.index_exp[True], np.index_exp[False], np.index_exp[..., True], np.index_exp[..., False], np.index_exp[0, :2, True], np.index_exp[0, :2, False], np.index_exp[:2, 0, True], np.index_exp[:2, 0, False], np.index_exp[:2, np.array([0, 2]), True], np.index_exp[np.array([1, 0]), :, True], np.index_exp[True, :, True, :, np.array(True)], ] ) def testScalarBoolUpdate(self, shape, idx, update_ndim): update_shape = np.zeros(shape)[idx].shape[-update_ndim:] rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, np.int32), rng(update_shape, np.int32)] def np_fun(x, update): x = np.array(x, copy=True) x[idx] = update return x jnp_fun = lambda x, update: jnp.asarray(x).at[idx].set(update) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) def testFloatIndexingError(self): BAD_INDEX_TYPE_ERROR = "Indexer must have integer or boolean type, got indexer with type" with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): jnp.zeros(2)[0.] with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): jnp.zeros((2, 2))[(0, 0.)] with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): jnp.zeros((2, 2))[(0, 0.)] with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): jax.jit(lambda idx: jnp.zeros((2, 2))[idx])((0, 0.)) with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): jnp.zeros(2).at[0.].add(1.) with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): jnp.zeros(2).at[0.].set(1.) with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): jnp.zeros((2, 2))[jnp.arange(2), 1.0] with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): jnp.zeros((2, 2))[jnp.arange(2), 1 + 1j] def testStrIndexingError(self): msg = "JAX does not support string indexing" with self.assertRaisesRegex(TypeError, msg): jnp.zeros(2)['abc'] with self.assertRaisesRegex(TypeError, msg): jnp.zeros((2, 3))[:, 'abc'] def testIndexOutOfBounds(self): # https://github.com/jax-ml/jax/issues/2245 x = jnp.arange(5, dtype=jnp.int32) + 1 self.assertAllClose(x, x[:10]) idx = jnp.array([-10, -6, -5, -4, 0, 3, 4, 5, 6, 100]) self.assertArraysEqual( x.at[idx].get(mode="clip"), jnp.array([1, 1, 1, 2, 1, 4, 5, 5, 5, 5], jnp.int32)) nan = np.nan self.assertArraysEqual( x.astype(jnp.float32).at[idx].get(mode="fill"), jnp.array([nan, nan, 1, 2, 1, 4, 5, nan, nan, nan], jnp.float32)) imin = np.iinfo(np.int32).min self.assertArraysEqual( x.at[idx].get(mode="fill"), jnp.array([imin, imin, 1, 2, 1, 4, 5, imin, imin, imin], jnp.int32)) umax = np.iinfo(np.uint32).max self.assertArraysEqual( x.astype(np.uint32).at[idx].get(mode="fill"), jnp.array([umax, umax, 1, 2, 1, 4, 5, umax, umax, umax], jnp.uint32)) self.assertArraysEqual( x.at[idx].get(mode="fill", fill_value=7), jnp.array([7, 7, 1, 2, 1, 4, 5, 7, 7, 7], jnp.int32)) def testIndexingWeakTypes(self): x = lax_internal._convert_element_type(jnp.arange(5), float, weak_type=True) a = x.at[0].set(1.0) self.assertEqual(a.dtype, x.dtype) self.assertTrue(dtypes.is_weakly_typed(a)) b = x.at[0].add(1.0) self.assertEqual(b.dtype, x.dtype) self.assertTrue(dtypes.is_weakly_typed(b)) c = x.at[0].mul(1.0) self.assertEqual(c.dtype, x.dtype) self.assertTrue(dtypes.is_weakly_typed(c)) def testIndexingTypePromotion(self): def _check(x_type, y_type): x = jnp.arange(5, dtype=x_type) y = y_type(0) out = x.at[0].set(y) self.assertEqual(x.dtype, out.dtype) @jtu.ignore_warning(category=NumpyComplexWarning, message="Casting complex values to real") def _check_warns(x_type, y_type, msg): with self.assertWarnsRegex(FutureWarning, msg): _check(x_type, y_type) def _check_raises(x_type, y_type, msg): with self.assertRaisesRegex(ValueError, msg): _check(x_type, y_type) # Matching dtypes are always OK _check(jnp.int32, jnp.int32) _check(jnp.float32, jnp.float32) _check(jnp.complex64, jnp.complex64) # Weakly-typed y values promote. _check(jnp.int32, int) _check(jnp.float32, int) _check(jnp.float32, float) _check(jnp.complex64, int) _check(jnp.complex64, float) _check(jnp.complex64, complex) # in standard promotion mode, strong types can promote. msg = "scatter inputs have incompatible types" with jax.numpy_dtype_promotion('standard'): _check(jnp.int32, jnp.int16) _check(jnp.float32, jnp.float16) _check(jnp.float32, jnp.int32) _check(jnp.complex64, jnp.int32) _check(jnp.complex64, jnp.float32) # TODO(jakevdp): make these _check_raises _check_warns(jnp.int16, jnp.int32, msg) _check_warns(jnp.int32, jnp.float32, msg) _check_warns(jnp.int32, jnp.complex64, msg) _check_warns(jnp.float16, jnp.float32, msg) _check_warns(jnp.float32, jnp.complex64, msg) # in strict promotion mode, strong types do not promote. msg = "Input dtypes .* have no available implicit dtype promotion path" with jax.numpy_dtype_promotion('strict'): _check_raises(jnp.int32, jnp.int16, msg) _check_raises(jnp.float32, jnp.float16, msg) _check_raises(jnp.float32, jnp.int32, msg) _check_raises(jnp.complex64, jnp.int32, msg) _check_raises(jnp.complex64, jnp.float32, msg) _check_raises(jnp.int16, jnp.int32, msg) _check_raises(jnp.int32, jnp.float32, msg) _check_raises(jnp.int32, jnp.complex64, msg) _check_raises(jnp.float16, jnp.float32, msg) _check_raises(jnp.float32, jnp.complex64, msg) def testWrongNumberOfIndices(self): with self.assertRaisesRegex( IndexError, "Too many indices: 0-dimensional array indexed with 1 regular index."): jnp.array(1)[0] with self.assertRaisesRegex( IndexError, "Too many indices: 1-dimensional array indexed with 2 regular indices."): jnp.zeros(3)[:, 5] def _broadcastable_shapes(shape): """Returns all shapes that broadcast to `shape`.""" def f(rshape): yield [] if rshape: for s in f(rshape[1:]): yield rshape[0:1] + s if rshape[0] != 1: for s in f(rshape[1:]): yield [1] + s for x in f(list(reversed(shape))): yield list(reversed(x)) # TODO(jakevdp): move this implementation to jax.dtypes & use in scatter? def _can_cast(from_, to): with jax.numpy_dtype_promotion('standard'): return lax.dtype(to) == dtypes.result_type(from_, to) def _compatible_dtypes(op, dtype, inexact=False): if op == UpdateOps.ADD or op == UpdateOps.SUB: return [dtype] elif inexact: return [dt for dt in float_dtypes if _can_cast(dt, dtype)] else: return [dt for dt in all_dtypes if _can_cast(dt, dtype)] class UpdateOps(enum.Enum): UPDATE = 0 ADD = 1 SUB = 2 MUL = 3 DIV = 4 POW = 5 MIN = 6 MAX = 7 def np_fn(op, indexer, x, y): x = x.copy() x[indexer] = { UpdateOps.UPDATE: lambda: y, UpdateOps.ADD: lambda: x[indexer] + y, UpdateOps.SUB: lambda: x[indexer] - y, UpdateOps.MUL: lambda: x[indexer] * y, UpdateOps.DIV: jtu.ignore_warning(category=RuntimeWarning)( lambda: x[indexer] / y.astype(x.dtype)), UpdateOps.POW: jtu.ignore_warning(category=RuntimeWarning)( lambda: x[indexer] ** y.astype(x.dtype)), UpdateOps.MIN: lambda: np.minimum(x[indexer], y), UpdateOps.MAX: lambda: np.maximum(x[indexer], y), }[op]() return x def jax_fn(op, indexer, x, y, indices_are_sorted=False, unique_indices=False, mode=None): x = jnp.array(x) return { UpdateOps.UPDATE: x.at[indexer].set, UpdateOps.ADD: x.at[indexer].add, UpdateOps.SUB: x.at[indexer].subtract, UpdateOps.MUL: x.at[indexer].multiply, UpdateOps.DIV: x.at[indexer].divide, UpdateOps.POW: x.at[indexer].power, UpdateOps.MIN: x.at[indexer].min, UpdateOps.MAX: x.at[indexer].max, }[op](y, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode) def dtypes(op): if op == UpdateOps.UPDATE: return all_dtypes elif op == UpdateOps.DIV or op == UpdateOps.POW: return jtu.dtypes.inexact else: return default_dtypes def _update_tol(op): if op == UpdateOps.POW: f32_tol = 2e-4 if jtu.test_device_matches(["tpu"]) else 1e-5 tol = {np.float32: f32_tol, np.complex64: f32_tol, np.complex128: 1e-14} else: tol = {np.complex128: 1e-14} return tol class IndexedUpdateTest(jtu.JaxTestCase): @jtu.sample_product( [dict(name=name, shape=shape, indexer=indexer, update_shape=update_shape) for name, index_specs in STATIC_INDEXING_TESTS for shape, indexer, index_shape in index_specs for update_shape in _broadcastable_shapes(index_shape) ], [dict(op=op, dtype=dtype, update_dtype=update_dtype) for op in UpdateOps for dtype in UpdateOps.dtypes(op) for update_dtype in _compatible_dtypes(op, dtype) ], mode=MODES, ) def testStaticIndexing(self, name, shape, dtype, update_shape, update_dtype, indexer, op, mode): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y) jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y, mode=mode) with jtu.strict_promotion_if_dtypes_match([dtype, update_dtype]): self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol=_update_tol(op)) self._CompileAndCheck(jax_fn, args_maker) @jtu.sample_product( [dict(name=name, shape=shape, indexer=indexer, update_shape=update_shape) for name, index_specs in ADVANCED_INDEXING_TESTS_NO_REPEATS for shape, indexer, index_shape in index_specs for update_shape in _broadcastable_shapes(index_shape) ], [dict(op=op, dtype=dtype, update_dtype=update_dtype) for op in UpdateOps for dtype in UpdateOps.dtypes(op) for update_dtype in _compatible_dtypes(op, dtype) ], ) def testAdvancedIndexing(self, name, shape, dtype, update_shape, update_dtype, indexer, op): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y) jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y, unique_indices=True) with jtu.strict_promotion_if_dtypes_match([dtype, update_dtype]): self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol=_update_tol(op)) self._CompileAndCheck(jax_fn, args_maker) @jtu.sample_product( [dict(name=name, shape=shape, indexer=indexer, update_shape=update_shape) for name, index_specs in ADVANCED_INDEXING_TESTS_NO_REPEATS_SORTED for shape, indexer, index_shape in index_specs for update_shape in _broadcastable_shapes(index_shape) ], [dict(op=op, dtype=dtype, update_dtype=update_dtype) for op in UpdateOps for dtype in UpdateOps.dtypes(op) for update_dtype in _compatible_dtypes(op, dtype) ], ) def testAdvancedIndexingSorted(self, name, shape, dtype, update_shape, update_dtype, indexer, op): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y) jax_fn = lambda x, y: UpdateOps.jax_fn( op, indexer, x, y, indices_are_sorted=True, unique_indices=True) with jtu.strict_promotion_if_dtypes_match([dtype, update_dtype]): self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, check_dtypes=True, tol=_update_tol(op)) self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True) @jtu.sample_product( [dict(name=name, shape=shape, indexer=indexer, update_shape=update_shape) for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS for shape, indexer, index_shape in index_specs for update_shape in _broadcastable_shapes(index_shape) ], [dict(op=op, dtype=dtype, update_dtype=update_dtype) for op in UpdateOps for dtype in UpdateOps.dtypes(op) for update_dtype in _compatible_dtypes(op, dtype) ], ) def testMixedAdvancedIndexing(self, name, shape, dtype, update_shape, update_dtype, indexer, op): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y) jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y) with jtu.strict_promotion_if_dtypes_match([dtype, update_dtype]): self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol=_update_tol(op)) self._CompileAndCheck(jax_fn, args_maker) @jtu.sample_product( [dict(name=name, mode=mode, shape=shape, indexer=indexer, update_shape=update_shape) for mode in [None] + MODES for name, index_specs in ( STATIC_INDEXING_TESTS if mode == "promise_in_bounds" else STATIC_INDEXING_TESTS + STATIC_INDEXING_OUT_OF_BOUNDS_TESTS) for shape, indexer, index_shape in index_specs for update_shape in _broadcastable_shapes(index_shape) ], [dict(op=op, dtype=dtype, update_dtype=update_dtype) for op in [UpdateOps.ADD, UpdateOps.SUB, UpdateOps.MUL, UpdateOps.UPDATE] for dtype in float_dtypes for update_dtype in _compatible_dtypes(op, dtype, inexact=True) ], ) def testStaticIndexingGrads(self, name, shape, dtype, update_shape, update_dtype, indexer, op, mode): rng = jtu.rand_default(self.rng()) jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y, mode=mode, unique_indices=True) x = rng(shape, dtype) y = rng(update_shape, update_dtype) with jtu.strict_promotion_if_dtypes_match([dtype, update_dtype]): check_grads(jax_fn, (x, y), 2, rtol=1e-3, atol=1e-3, eps=1.) @parameterized.parameters(itertools.chain.from_iterable( jtu.sample_product_testcases( [dict(name=name, unique_indices=unique_indices, shape=shape, indexer=indexer, update_shape=update_shape) for name, index_specs in ( ADVANCED_INDEXING_TESTS_NO_REPEATS if unique_indices else ADVANCED_INDEXING_TESTS) for shape, indexer, index_shape in index_specs for update_shape in _broadcastable_shapes(index_shape) ], [dict(op=op, dtype=dtype, update_dtype=update_dtype) for op in ( [UpdateOps.ADD, UpdateOps.SUB, UpdateOps.MUL, UpdateOps.UPDATE] if unique_indices else [UpdateOps.ADD, UpdateOps.SUB]) for dtype in float_dtypes for update_dtype in _compatible_dtypes(op, dtype, inexact=True) ], ) for unique_indices in [False, True] )) def testAdvancedIndexingGrads(self, name, shape, dtype, update_shape, update_dtype, indexer, op, unique_indices): rng = jtu.rand_default(self.rng()) jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y, unique_indices=unique_indices) x = rng(shape, dtype) y = rng(update_shape, update_dtype) with jtu.strict_promotion_if_dtypes_match([dtype, update_dtype]): check_grads(jax_fn, (x, y), 2, rtol=1e-3, atol=1e-3, eps=1.) def testIndexMulGradFailsIfNotUnique(self): y = jnp.ones((10,), jnp.int32) f = lambda x, z: x.at[y].mul(z) x = jnp.ones((100,), jnp.float32) z = jnp.ones((10,), jnp.float32) with self.assertRaises(NotImplementedError, msg="scatter_mul gradients are only implemented if " "`unique_indices=True`"): jax.jvp(f, (x, z), (x, z)) def testSegmentSumBehavior(self): # testAdvancedIndexing compares against NumPy, and as a result doesn't check # repeated indices. This test is just a simple manual check, based on # https://www.tensorflow.org/api_docs/python/tf/math/segment_sum data = np.array([5, 1, 7, 2, 3, 4, 1, 3], dtype=float) segment_ids = np.array([0, 0, 0, 1, 2, 2, 3, 3]) ans = jnp.zeros_like(data, shape=np.max(segment_ids) + 1).at[segment_ids].add(data) expected = np.array([13, 2, 7, 4]) self.assertAllClose(ans, expected, check_dtypes=False) def testSegmentSum(self): data = jnp.array([5, 1, 7, 2, 3, 4, 1, 3]) segment_ids = jnp.array([0, 0, 0, 1, 2, 2, 3, 3]) # test with explicit num_segments ans = ops.segment_sum(data, segment_ids, num_segments=4) expected = jnp.array([13, 2, 7, 4]) self.assertAllClose(ans, expected, check_dtypes=False) # test with explicit num_segments larger than the higher index. ans = ops.segment_sum(data, segment_ids, num_segments=5) expected = jnp.array([13, 2, 7, 4, 0]) self.assertAllClose(ans, expected, check_dtypes=False) # test without explicit num_segments ans = ops.segment_sum(data, segment_ids) expected = jnp.array([13, 2, 7, 4]) self.assertAllClose(ans, expected, check_dtypes=False) # test with negative segment ids and segment ids larger than num_segments, # that will be wrapped with the `mod`. segment_ids = jnp.array([0, 4, 8, 1, 2, -6, -1, 3]) ans = ops.segment_sum(data, segment_ids, num_segments=4) expected = jnp.array([5, 2, 3, 3]) self.assertAllClose(ans, expected, check_dtypes=False) # test with negative segment ids and without explicit num_segments # such as num_segments is defined by the smaller index. segment_ids = jnp.array([3, 3, 3, 4, 5, 5, -7, -6]) ans = ops.segment_sum(data, segment_ids) expected = jnp.array([0, 0, 0, 13, 2, 7]) self.assertAllClose(ans, expected, check_dtypes=False) def testSegmentSumOutOfBounds(self): def fn(data, segment_ids): return jax.ops.segment_sum(data, segment_ids, num_segments).sum() data = np.array([0, 0], dtype=np.float32) num_segments = 2 segment_ids = np.array([2, 3]) val, grad = jax.value_and_grad(fn)(data, segment_ids) self.assertAllClose(val, np.array(0., np.float32)) self.assertAllClose(grad, np.array([0., 0.], np.float32)) @parameterized.parameters(itertools.chain.from_iterable( jtu.sample_product_testcases( [dict(reducer=reducer, op=op, identity=identity)], dtype=[np.bool_], shape=[(8,), (7, 4), (6, 4, 2)], bucket_size=[None, 2], num_segments=[None, 1, 3], ) for reducer, op, identity in [ (ops.segment_min, np.minimum, True), (ops.segment_max, np.maximum, False), ])) def testSegmentReduceBoolean(self, shape, dtype, reducer, op, identity, num_segments, bucket_size): rng = jtu.rand_default(self.rng()) idx_rng = jtu.rand_int(self.rng(), low=-2, high=3) args_maker = lambda: [rng(shape, dtype), idx_rng(shape[:1], jnp.int32)] if np.issubdtype(dtype, np.integer): if np.isposinf(identity): identity = np.iinfo(dtype).max elif np.isneginf(identity): identity = np.iinfo(dtype).min jnp_fun = lambda data, segment_ids: reducer( data, segment_ids, num_segments=num_segments, bucket_size=bucket_size) def np_fun(data, segment_ids): size = num_segments if num_segments is not None else (segment_ids.max() + 1) out = np.full((size,) + shape[1:], identity, dtype) for i, val in zip(segment_ids, data): if 0 <= i < size: out[i] = op(out[i], val).astype(dtype) return out self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) if num_segments is not None: self._CompileAndCheck(jnp_fun, args_maker) @parameterized.parameters(itertools.chain.from_iterable( jtu.sample_product_testcases( [dict(reducer=reducer, op=op, identity=identity)], dtype=default_dtypes, shape=[(8,), (7, 4), (6, 4, 2)], bucket_size=[None, 2], num_segments=[None, 1, 3], ) for reducer, op, identity in [ (ops.segment_sum, np.add, 0), (ops.segment_prod, np.multiply, 1), (ops.segment_min, np.minimum, float('inf')), (ops.segment_max, np.maximum, -float('inf')), ])) def testSegmentReduce(self, shape, dtype, reducer, op, identity, num_segments, bucket_size): rng = jtu.rand_default(self.rng()) idx_rng = jtu.rand_int(self.rng(), low=-2, high=3) args_maker = lambda: [rng(shape, dtype), idx_rng(shape[:1], jnp.int32)] if np.issubdtype(dtype, np.integer): if np.isposinf(identity): identity = np.iinfo(dtype).max elif np.isneginf(identity): identity = np.iinfo(dtype).min jnp_fun = lambda data, segment_ids: reducer( data, segment_ids, num_segments=num_segments, bucket_size=bucket_size) def np_fun(data, segment_ids): size = num_segments if num_segments is not None else (segment_ids.max() + 1) out = np.full((size,) + shape[1:], identity, dtype) for i, val in zip(segment_ids, data): if 0 <= i < size: out[i] = op(out[i], val).astype(dtype) return out self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) if num_segments is not None: self._CompileAndCheck(jnp_fun, args_maker) def testIndexDtypeError(self): # https://github.com/jax-ml/jax/issues/2795 jnp.array(1) # get rid of startup warning with self.assertNoWarnings(): jnp.zeros(5).at[::2].set(1) @jtu.sample_product( [dict(idx=idx, idx_type=idx_type) for idx, idx_type in [ ([0], "array"), ([0, 0], "array"), ([[0, 0]], "tuple"), ([0, [0, 1]], "tuple"), ([0, np.arange(2)], "tuple"), ([0, None], "tuple"), ([0, slice(None)], "tuple"), ] ], ) def testIndexSequenceDeprecation(self, idx, idx_type): normalize = {"array": np.array, "tuple": tuple}[idx_type] msg = {"array": ARRAY_MSG, "tuple": TUPLE_MSG}[idx_type] x = jnp.arange(6).reshape(3, 2) with self.assertRaisesRegex(TypeError, msg): x[idx] with self.assertNoWarnings(): x[normalize(idx)] with self.assertRaisesRegex(TypeError, msg): x.at[idx].set(0) with self.assertNoWarnings(): x.at[normalize(idx)].set(0) def testIndexedUpdateAliasingBug(self): # https://github.com/jax-ml/jax/issues/7461 fn = lambda x: x.at[1:].set(1 + x[:-1]) y = jnp.zeros(8) self.assertArraysEqual(fn(y), jax.jit(fn)(y)) def testScatterValuesCastToTargetDType(self): # https://github.com/jax-ml/jax/issues/15505 a = jnp.zeros(1, dtype=jnp.uint32) val = 2**32 - 1 # too large for int32 b = a.at[0].set(jnp.uint32(val)) self.assertEqual(int(b[0]), val) c = a.at[0].set(val) self.assertEqual(int(c[0]), val) def testGradOfVmapOfScatter(self): # Regression test for https://github.com/jax-ml/jax/issues/25878 def f(x, i): return x.at[i].get(mode='clip') x = jnp.array([1.0]) i = jnp.array([1]) # out-of-bound index expected = jnp.array([[1.0]]) self.assertArraysEqual(jax.jacrev(f)(x, i), expected) self.assertArraysEqual(jax.jacrev(jax.vmap(f, (None, 0)))(x, i), expected) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())