rocm_jax/tests/lax_numpy_indexing_test.py

1696 lines
65 KiB
Python
Raw Permalink Normal View History

# Copyright 2018 The JAX Authors.
2018-11-17 18:03:33 -08:00
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import enum
2018-11-17 18:03:33 -08:00
from functools import partial
import itertools
2021-08-11 17:32:36 -04:00
import typing
from typing import Any
2018-11-17 18:03:33 -08:00
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
2018-11-17 18:03:33 -08:00
import jax
2022-04-20 16:04:12 -07:00
from jax import lax
from jax import numpy as jnp
from jax import ops
from jax._src import config
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src import util
from jax._src.lax import lax as lax_internal
from jax._src.util import NumpyComplexWarning
2018-11-17 18:03:33 -08:00
config.parse_flags_with_absl()
2018-11-17 18:03:33 -08:00
# We disable the whitespace continuation check in this file because otherwise it
# makes the test name formatting unwieldy.
# pylint: disable=bad-continuation
# Regex patterns for the error raised when a non-tuple sequence is used as a
# multidimensional index. They differ only in the rewrite the message suggests
# (`arr[array(seq)]` vs. `arr[tuple(seq)]`).
ARRAY_MSG = r"Using a non-tuple sequence for multidimensional indexing is not allowed.*arr\[array\(seq\)\]"
TUPLE_MSG = r"Using a non-tuple sequence for multidimensional indexing is not allowed.*arr\[tuple\(seq\)\]"
# Dtype groups used to parameterize the tests: floats only, floats + integers,
# and floats + integers + booleans.
float_dtypes = jtu.dtypes.floating
default_dtypes = float_dtypes + jtu.dtypes.integer
all_dtypes = default_dtypes + jtu.dtypes.boolean
2018-11-17 18:03:33 -08:00
2021-08-11 17:32:36 -04:00
class IndexSpec(typing.NamedTuple):
  """A single indexing test case.

  Attributes:
    shape: shape of the array that gets indexed.
    indexer: the index expression applied to that array (int, slice, tuple,
      list, array, ``None``, ``Ellipsis``, ...).
    out_shape: shape of the indexed result, or ``None`` when not recorded.
  """
  shape: tuple[int, ...]
  indexer: Any
  out_shape: tuple[int, ...] | None = None
2018-11-17 18:03:33 -08:00
def check_grads(f, args, order, atol=None, rtol=None, eps=None):
  """Run `jtu.check_jvp` and `jtu.check_vjp` on `f` at `args`.

  Args:
    f: function to check.
    args: tuple of positional arguments at which `f` is checked.
    order: accepted for call-site compatibility; currently unused (see TODO).
    atol, rtol, eps: forwarded to the jtu checkers. Any value left falsy
      (``None`` or 0) is replaced by a default of 1e-6 when x64 is enabled
      and 1e-2 otherwise.
  """
  # TODO(mattjj,dougalm): add higher-order check
  default_tol = 1e-6 if config.enable_x64.value else 1e-2
  atol = atol or default_tol
  rtol = rtol or default_tol
  eps = eps or default_tol
  jtu.check_jvp(f, partial(jax.jvp, f), args, atol, rtol, eps)
  jtu.check_vjp(f, partial(jax.vjp, f), args, atol, rtol, eps)
2018-11-17 18:03:33 -08:00
# Indexing cases used by the static-indexing tests. Each entry is a
# (group name, [IndexSpec, ...]) pair; out_shape is the shape of x[indexer]
# for an x of the given shape.
STATIC_INDEXING_TESTS = [
  ("OneIntIndex", [
    IndexSpec(shape=(3,), indexer=1, out_shape=()),
    IndexSpec(shape=(3, 3), indexer=0, out_shape=(3,)),
    IndexSpec(shape=(3, 4, 5), indexer=2, out_shape=(4, 5)),
    IndexSpec(shape=(3,), indexer=-1, out_shape=()),
    IndexSpec(shape=(3,), indexer=-2, out_shape=()),
  ]),
  ("TwoIntIndices", [
    IndexSpec(shape=(3, 3), indexer=(2, 1), out_shape=()),
    IndexSpec(shape=(3, 4, 5), indexer=(1, 2), out_shape=(5,)),
    IndexSpec(shape=(3, 4, 5), indexer=(-1, 2), out_shape=(5,)),
  ]),
  ("ThreeIntIndices", [
    IndexSpec(shape=(3, 4, 5), indexer=(1, 2, 3), out_shape=()),
  ]),
  ("OneSliceIndex", [
    IndexSpec(shape=(10,), indexer=slice(1, 3), out_shape=(2,)),
    IndexSpec(shape=(10,), indexer=slice(1, -1), out_shape=(8,)),
    IndexSpec(shape=(10,), indexer=slice(None, -1), out_shape=(9,)),
    IndexSpec(shape=(10,), indexer=slice(None, None, None), out_shape=(10,)),
    IndexSpec(shape=(10, 8), indexer=slice(1, 3), out_shape=(2, 8)),
    IndexSpec(shape=(10, 8), indexer=slice(1, None), out_shape=(9, 8)),
    IndexSpec(shape=(10, 8), indexer=slice(None, 3), out_shape=(3, 8)),
    IndexSpec(shape=(10, 8), indexer=slice(-3, None), out_shape=(3, 8)),
  ]),
  ("OneSliceIndexNegativeStride", [
    IndexSpec(shape=(10,), indexer=slice(3, 1, -1), out_shape=(2,)),
    IndexSpec(shape=(10,), indexer=slice(1, 8, -1), out_shape=(0,)),
    IndexSpec(shape=(10,), indexer=slice(None, 1, -2), out_shape=(4,)),
    IndexSpec(shape=(10,), indexer=slice(None, None, -1), out_shape=(10,)),
    IndexSpec(shape=(10, 8), indexer=slice(3, 1, -1), out_shape=(2, 8)),
    IndexSpec(shape=(10, 8), indexer=slice(0, 8, -1), out_shape=(0, 8)),
    IndexSpec(shape=(10, 8), indexer=slice(None, None, -1), out_shape=(10, 8)),
  ]),
  ("SliceIndexClamping", [
    IndexSpec(shape=(10,), indexer=slice(2, 11, 1), out_shape=(8,)),
    IndexSpec(shape=(10,), indexer=slice(11, 12, 1), out_shape=(0,)),
    IndexSpec(shape=(10,), indexer=slice(-11, -2, 1), out_shape=(8,)),
    IndexSpec(shape=(10,), indexer=slice(-2, -12, -1), out_shape=(9,)),
    IndexSpec(shape=(10,), indexer=slice(12, -12, -1), out_shape=(10,)),
  ]),
  ("OneSliceIndexNonUnitStride", [
    IndexSpec(shape=(10,), indexer=slice(0, 8, 2), out_shape=(4,)),
    IndexSpec(shape=(10,), indexer=slice(0, 8, 3), out_shape=(3,)),
    IndexSpec(shape=(10,), indexer=slice(1, 3, 2), out_shape=(1,)),
    IndexSpec(shape=(10,), indexer=slice(1, None, 2), out_shape=(5,)),
    IndexSpec(shape=(10,), indexer=slice(None, 1, -2), out_shape=(4,)),
    IndexSpec(shape=(10, 8), indexer=slice(1, 8, 3), out_shape=(3, 8)),
    IndexSpec(shape=(10, 8), indexer=slice(None, None, 2), out_shape=(5, 8)),
    IndexSpec(shape=(10, 8), indexer=slice(None, 1, -2), out_shape=(4, 8)),
    IndexSpec(shape=(10, 8), indexer=slice(None, None, -2), out_shape=(5, 8)),
  ]),
  ("TwoSliceIndices", [
    IndexSpec(shape=(10, 8), indexer=(slice(1, 3), slice(0, 2)),
              out_shape=(2, 2)),
    IndexSpec(shape=(10, 8), indexer=(slice(1, None), slice(None, 2)),
              out_shape=(9, 2)),
    IndexSpec(shape=(10, 8), indexer=(slice(None, None, -1), slice(None, 2)),
              out_shape=(10, 2)),
    IndexSpec(shape=(10, 8, 3), indexer=(slice(1, 3), slice(0, 2)),
              out_shape=(2, 2, 3)),
    IndexSpec(shape=(10, 8, 3), indexer=(slice(1, 3), slice(0, None)),
              out_shape=(2, 8, 3)),
    IndexSpec(shape=(10, 8, 3), indexer=(slice(1, None), slice(0, 2)),
              out_shape=(9, 2, 3)),
  ]),
  ("OneColonIndex", [
    IndexSpec(shape=(3,), indexer=slice(None), out_shape=(3,)),
    IndexSpec(shape=(3, 4), indexer=slice(None), out_shape=(3, 4)),
  ]),
  ("MultipleColonIndices", [
    IndexSpec(shape=(3, 4), indexer=(slice(None), slice(None)),
              out_shape=(3, 4)),
    IndexSpec(shape=(3, 4, 5), indexer=(slice(None), slice(None)),
              out_shape=(3, 4, 5)),
  ]),
  ("MixedSliceIndices", [
    IndexSpec(shape=(10, 4), indexer=(slice(None), slice(0, 2)),
              out_shape=(10, 2)),
    IndexSpec(shape=(10, 4), indexer=(1, slice(None)),
              out_shape=(4,)),
  ]),
  ("EllipsisIndex", [
    IndexSpec(shape=(3,), indexer=Ellipsis, out_shape=(3,)),
    IndexSpec(shape=(3, 4), indexer=Ellipsis, out_shape=(3, 4)),
    IndexSpec(shape=(3, 4, 5), indexer=(0, Ellipsis), out_shape=(4, 5)),
    IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, 2, 3), out_shape=(3,)),
  ]),
  ("NoneIndex", [
    IndexSpec(shape=(), indexer=None, out_shape=(1,)),
    IndexSpec(shape=(), indexer=(None, None), out_shape=(1, 1)),
    IndexSpec(shape=(), indexer=(Ellipsis, None), out_shape=(1,)),
    IndexSpec(shape=(3,), indexer=None, out_shape=(1, 3)),
    IndexSpec(shape=(3, 4), indexer=None, out_shape=(1, 3, 4)),
    IndexSpec(shape=(3, 4), indexer=(Ellipsis, None), out_shape=(3, 4, 1)),
    IndexSpec(shape=(3, 4), indexer=(0, None, Ellipsis), out_shape=(1, 4)),
    IndexSpec(shape=(3, 4, 5), indexer=(1, None, Ellipsis), out_shape=(1, 4, 5)),
  ]),
  ("EmptyIndex", [
    IndexSpec(shape=(), indexer=(), out_shape=()),
    IndexSpec(shape=(3,), indexer=(), out_shape=(3,)),
    IndexSpec(shape=(3, 4), indexer=(), out_shape=(3, 4)),
  ]),
  ("TupleOfIntAndSliceAndIntArray", [
    IndexSpec(shape=(3, 2, 3), indexer=(0, slice(None), np.arange(3)),
              out_shape=(3, 2)),
    IndexSpec(shape=(3, 2, 3), indexer=(np.int32(1), slice(None), np.arange(3)),
              out_shape=(3, 2)),
    IndexSpec(shape=(3, 2, 3), indexer=(np.array(2), slice(None), np.arange(3)),
              out_shape=(3, 2)),
  ]),
]
# Integer indices that fall outside the bounds of the indexed array. These are
# only combined with the "clip" and "drop" modes in testStaticIndexingGrads.
# out_shape is the result shape: one leading axis is consumed per integer.
STATIC_INDEXING_OUT_OF_BOUNDS_TESTS = [
  ("OneIntIndex", [
    IndexSpec(shape=(3,), indexer=-4, out_shape=()),
    IndexSpec(shape=(3, 3), indexer=3, out_shape=(3,)),
    IndexSpec(shape=(3, 4, 5), indexer=4, out_shape=(4, 5)),
  ]),
  ("TwoIntIndices", [
    IndexSpec(shape=(3, 3), indexer=(2, -4), out_shape=()),
    # Two integer indices on a rank-3 array leave the trailing axis of size 5.
    IndexSpec(shape=(3, 4, 5), indexer=(3, 2), out_shape=(5,)),
    IndexSpec(shape=(3, 4, 5), indexer=(-4, 4), out_shape=(5,)),
  ]),
]
# Integer-array ("advanced") indexing cases; index arrays may repeat entries.
# Each entry is a (group name, [IndexSpec, ...]) pair.
ADVANCED_INDEXING_TESTS = [
  ("One1DIntArrayIndex", [
    IndexSpec(shape=(3,), indexer=np.array([0, 1]), out_shape=(2,)),
    IndexSpec(shape=(3, 3), indexer=np.array([1, 2, 1]), out_shape=(3, 3)),
    IndexSpec(shape=(3, 4, 5), indexer=np.array([0, 2, 0, 1]),
              out_shape=(4, 4, 5)),
    IndexSpec(shape=(3,), indexer=np.array([-1, 1]), out_shape=(2,)),
    IndexSpec(shape=(3,), indexer=np.array([-2, -1]), out_shape=(2,)),
    IndexSpec(shape=(0,), indexer=np.array([], dtype=np.int32),
              out_shape=(0,)),
  ]),
  ("One2DIntArrayIndex", [
    IndexSpec(shape=(3,), indexer=np.array([[0, 0]]), out_shape=(1, 2)),
    IndexSpec(shape=(3, 3), indexer=np.array([[1, 2, 1], [0, 1, -1]]),
              out_shape=(2, 3, 3)),
    IndexSpec(shape=(3, 4, 5), indexer=np.array([[0, 2, 0, 1], [-1, -2, 1, 0]]),
              out_shape=(2, 4, 4, 5)),
  ]),
  ("Two1DIntArrayIndicesNoBroadcasting", [
    IndexSpec(shape=(3, 3), indexer=(np.array([0, 1]), np.array([1, 2])),
              out_shape=(2,)),
    IndexSpec(shape=(3, 4, 5),
              indexer=(np.array([0, 2, 0, 1]), np.array([-1, 0, -1, 2])),
              out_shape=(4, 5)),
  ]),
  ("Two1DIntArrayIndicesWithBroadcasting", [
    IndexSpec(shape=(3, 3), indexer=(np.array([[0, 1]]), np.array([1, 2])),
              out_shape=(1, 2)),
    IndexSpec(shape=(3, 4, 5),
              indexer=(np.array([[0, 2, 0, 1]]), np.array([-1, 0, -1, 2])),
              out_shape=(1, 4, 5)),
  ]),
  ("ArrayOfInts", [
    IndexSpec(shape=(3,), indexer=np.array([0, 1, 0]), out_shape=(3,)),
    IndexSpec(shape=(3, 4, 5), indexer=np.array([ 0, -1]), out_shape=(2, 4, 5)),
  ]),
  ("TupleOfListsOfPythonInts", [
    IndexSpec(shape=(3, 4, 5), indexer=([0, 1],), out_shape=(2, 4, 5)),
    IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], [[2, 3, 0, 3]]),
              out_shape=(2, 4, 5)),
  ]),
  ("TupleOfPythonIntsAndIntArrays", [
    IndexSpec(shape=(3, 4, 5), indexer=(0, np.array([0, 1])), out_shape=(2, 5)),
    IndexSpec(shape=(3, 4, 5), indexer=(0, 1, np.array([[2, 3, 0, 3]])),
              out_shape=(1, 4)),
  ]),
  ("TupleOfListsOfPythonIntsAndIntArrays", [
    IndexSpec(shape=(3, 4, 5), indexer=([0, 1], np.array([0])),
              out_shape=(2, 5)),
    IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], np.array([[2, 3, 0, 3]])),
              out_shape=(2, 4, 5)),
  ]),
]
# Integer-array indexing cases in which no element of the indexed array is
# selected more than once.
ADVANCED_INDEXING_TESTS_NO_REPEATS = [
  ("One1DIntArrayIndex", [
    IndexSpec(shape=(3,), indexer=np.array([0, 1]), out_shape=(2,)),
    IndexSpec(shape=(3, 3), indexer=np.array([1, 2, 0]), out_shape=(3, 3)),
    IndexSpec(shape=(3, 4, 5), indexer=np.array([0, 2, 1]),
              out_shape=(3, 4, 5)),
    IndexSpec(shape=(3,), indexer=np.array([-1, 1]), out_shape=(2,)),
    IndexSpec(shape=(3,), indexer=np.array([-2, -1]), out_shape=(2,)),
    IndexSpec(shape=(0,), indexer=np.array([], dtype=np.int32), out_shape=(0,)),
  ]),
  ("One2DIntArrayIndex", [
    IndexSpec(shape=(3,), indexer=np.array([[0, 1]]), out_shape=(1, 2)),
    IndexSpec(shape=(6, 6), indexer=np.array([[1, 2, 0], [3, 4, -1]]),
              out_shape=(2, 3, 6)),
  ]),
  ("Two1DIntArrayIndicesNoBroadcasting", [
    IndexSpec(shape=(3, 3), indexer=(np.array([0, 1]), np.array([1, 2])),
              out_shape=(2,)),
    IndexSpec(shape=(4, 5, 6),
              indexer=(np.array([0, 2, 1, 3]), np.array([-1, 0, -2, 1])),
              out_shape=(4, 6)),
  ]),
  ("Two1DIntArrayIndicesWithBroadcasting", [
    IndexSpec(shape=(3, 3), indexer=(np.array([[0, 1]]), np.array([1, 2])),
              out_shape=(1, 2)),
    IndexSpec(shape=(4, 5, 6),
              indexer=(np.array([[0, 2, -1, 1]]), np.array([-1, 0, -2, 2])),
              out_shape=(1, 4, 6)),
  ]),
  ("ArrayOfInts", [
    IndexSpec(shape=(3,), indexer=np.array([0, 2, 1]), out_shape=(3,)),
    IndexSpec(shape=(3, 4, 5), indexer=np.array([ 0, -1]), out_shape=(2, 4, 5)),
  ]),
  ("TupleOfListsOfPythonInts", [
    IndexSpec(shape=(3, 4, 5), indexer=([0, 1],), out_shape=(2, 4, 5)),
    IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], [[2, 3, 0]]),
              out_shape=(2, 3, 5)),
  ]),
  ("TupleOfPythonIntsAndIntArrays", [
    IndexSpec(shape=(3, 4, 5), indexer=(0, np.array([0, 1])), out_shape=(2, 5)),
    IndexSpec(shape=(3, 4, 5), indexer=(0, 1, np.array([[2, 3, 0]])),
              out_shape=(1, 3)),
  ]),
  ("TupleOfListsOfPythonIntsAndIntArrays", [
    IndexSpec(shape=(3, 4, 5), indexer=([0, 1], np.array([0])),
              out_shape=(2, 5)),
    IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], np.array([[2, 3, 0]])),
              out_shape=(2, 3, 5)),
  ]),
]
# Variant of the no-repeats integer-array cases in which most index arrays
# are listed in increasing order.
ADVANCED_INDEXING_TESTS_NO_REPEATS_SORTED = [
  ("One1DIntArrayIndex", [
    IndexSpec(shape=(3,), indexer=np.array([0, 1]), out_shape=(2,)),
    IndexSpec(shape=(3, 3), indexer=np.array([0, 1, 2]), out_shape=(3, 3)),
    IndexSpec(shape=(3, 4, 5), indexer=np.array([0, 1, 2]),
              out_shape=(3, 4, 5)),
    IndexSpec(shape=(3,), indexer=np.array([-1, 1]), out_shape=(2,)),
    IndexSpec(shape=(3,), indexer=np.array([-2, -1]), out_shape=(2,)),
    IndexSpec(shape=(0,), indexer=np.array([], dtype=np.int32), out_shape=(0,)),
  ]),
  ("One2DIntArrayIndex", [
    IndexSpec(shape=(3,), indexer=np.array([[0, 1]]), out_shape=(1, 2)),
    IndexSpec(shape=(6, 6), indexer=np.array([[-1, 0, 1],
                                               [ 2, 3, 4]]), out_shape=(2, 3, 6)),
  ]),
  ("Two1DIntArrayIndicesNoBroadcasting", [
    IndexSpec(shape=(3, 3), indexer=(np.array([0, 1]), np.array([1, 2])),
              out_shape=(2,)),
    IndexSpec(shape=(4, 5, 6),
              indexer=(np.array([0, 1, 2, 3]), np.array([-2, -1, 0, 1])),
              out_shape=(4, 6)),
  ]),
  ("Two1DIntArrayIndicesWithBroadcasting", [
    IndexSpec(shape=(3, 3), indexer=(np.array([[0, 1]]), np.array([1, 2])),
              out_shape=(1, 2)),
    IndexSpec(shape=(4, 5, 6),
              indexer=(np.array([[-1, 0, 1, 2]]), np.array([-2, -1, 0, 2])),
              out_shape=(1, 4, 6)),
  ]),
  ("TupleOfListsOfPythonInts", [
    IndexSpec(shape=(3, 4, 5), indexer=([0, 1],), out_shape=(2, 4, 5)),
    IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], [[0, 2, 3]]),
              out_shape=(2, 3, 5)),
  ]),
  ("TupleOfPythonIntsAndIntArrays", [
    IndexSpec(shape=(3, 4, 5), indexer=(0, np.array([0, 1])), out_shape=(2, 5)),
    IndexSpec(shape=(3, 4, 5), indexer=(0, 1, np.array([[0, 2, 3]])),
              out_shape=(1, 3)),
  ]),
  ("TupleOfListsOfPythonIntsAndIntArrays", [
    IndexSpec(shape=(3, 4, 5), indexer=([0, 1], np.array([0])),
              out_shape=(2, 5)),
    IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], np.array([[0, 2, 3]])),
              out_shape=(2, 3, 5)),
  ]),
]
2021-08-11 17:32:36 -04:00
# Cases mixing integer-array indices with slices, None and Ellipsis, in which
# no element of the indexed array is selected more than once.
MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS = [
  ("SlicesAndOneIntArrayIndex", [
    IndexSpec(shape=(2, 3), indexer=(np.array([0, 1]), slice(1, 2)),
              out_shape=(2, 1)),
    IndexSpec(shape=(2, 3), indexer=(slice(0, 2), np.array([0, 2])),
              out_shape=(2, 2)),
    IndexSpec(shape=(3, 4, 5),
              indexer=(Ellipsis, np.array([0, 2]), slice(None)),
              out_shape=(3, 2, 5)),
    IndexSpec(shape=(3, 4, 5),
              indexer=(Ellipsis, np.array([[0, 2], [1, 3]]), slice(None)),
              out_shape=(3, 2, 2, 5)),
  ]),
  ("SlicesAndTwoIntArrayIndices", [
    IndexSpec(shape=(3, 4, 5),
              indexer=(Ellipsis, np.array([0, 2]), np.array([-1, 2])),
              out_shape=(3, 2)),
    IndexSpec(shape=(3, 4, 5),
              indexer=(np.array([0, 2]), Ellipsis, np.array([-1, 2])),
              out_shape=(2, 4)),
    IndexSpec(shape=(3, 4, 5),
              indexer=(np.array([0, 2]), np.array([-1, 2]), Ellipsis),
              out_shape=(2, 5)),
    IndexSpec(shape=(3, 4, 5),
              indexer=(np.array([0, 2]), np.array([-1, 2]), slice(1, 3)),
              out_shape=(2, 2)),
    IndexSpec(shape=(3, 4, 5),
              indexer=(np.array([0, 2]), slice(1, 3), np.array([-1, 2])),
              out_shape=(2, 2)),
    IndexSpec(shape=(3, 4, 5),
              indexer=(np.array([ 0, 2, -2]), slice(None, None, 2),
                       np.array([-1, 2, 1])),
              out_shape=(3, 2)),
  ]),
  ("NonesAndIntArrayIndices", [
    IndexSpec(shape=(3, 4, 5),
              indexer=(np.array([0, 2]), None, np.array([-1, 2])),
              out_shape=(2, 1, 5)),
    IndexSpec(shape=(3, 4, 5),
              indexer=(np.array([0, 2]), None, None, np.array([-1, 2])),
              out_shape=(2, 1, 1, 5)),
    IndexSpec(shape=(3, 4, 5),
              indexer=(Ellipsis, np.array([0, 2]), None, None,
                       np.array([-1, 2])),
              out_shape=(2, 3, 1, 1)),
  ]),
  ("IntArrayWithInt32Type", [
    IndexSpec(shape=(3, 4), indexer=(Ellipsis, np.array(1, dtype=np.int32)),
              out_shape=(3,)),
  ]),
  ("EllipsisWithArrayIndices", [
    IndexSpec(shape=(3, 4, 5), indexer=(np.array([0, 1]), ..., np.array([0, 1])),
              out_shape=(2, 4)),
    IndexSpec(shape=(3, 4, 5), indexer=(slice(None), np.array([0, 1]), ..., np.array([0, 1])),
              out_shape=(2, 3)),
    IndexSpec(shape=(3, 4, 5), indexer=(slice(None), ..., np.array([0, 1]), np.array([0, 1])),
              out_shape=(3, 2)),
  ]),
]
2021-08-11 17:32:36 -04:00
# The no-repeats mixed cases plus cases whose index arrays select some
# elements more than once.
MIXED_ADVANCED_INDEXING_TESTS = MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS + [
  ("SlicesAndOneIntArrayIndex", [
    IndexSpec(shape=(3, 4, 5),
              indexer=(Ellipsis, np.array([[0, 2], [1, 1]]), slice(None)),
              out_shape=(3, 2, 2, 5)),
  ]),
  ("SlicesAndTwoIntArrayIndices", [
    IndexSpec(shape=(3, 4, 5),
              indexer=(np.array([ 0, 2, -2]), slice(None, None, 2),
                       np.array([-1, 2, -1])),
              out_shape=(3, 2)),
    IndexSpec(shape=(3, 4, 5),
              indexer=(np.array([[0, 2], [2, 0]]), Ellipsis,
                       np.array([[1, 0], [1, 0]])),
              out_shape=(2, 2, 4)),
  ]),
]
# Values passed as `mode=` to `x.at[indexer].get(...)` in the gradient tests.
MODES = ["clip", "drop", "promise_in_bounds"]
2018-11-17 18:03:33 -08:00
class IndexingTest(jtu.JaxTestCase):
"""Tests for Numpy indexing translation rules."""
@jtu.sample_product(
  [dict(name=name, shape=shape, indexer=indexer)
   for name, index_specs in STATIC_INDEXING_TESTS
   for shape, indexer, _ in index_specs],
  dtype=all_dtypes
)
def testStaticIndexing(self, name, shape, dtype, indexer):
  # Compares jnp indexing with a fixed indexer against NumPy, both through
  # x[indexer] and through x.at[indexer].get().
  rng = jtu.rand_default(self.rng())
  args_maker = lambda: [rng(shape, dtype)]
  np_fun = lambda x: np.asarray(x)[indexer]
  jnp_funs = (
      lambda x: jnp.asarray(x)[indexer],
      # Tests x.at[...].get(...) as well.
      lambda x: jnp.asarray(x).at[indexer].get(),
  )
  for jnp_fun in jnp_funs:
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
    self._CompileAndCheck(jnp_fun, args_maker)
def testStaticIndexingWithJaxArray(self):
  # A slice whose start/stop/step are 0-d JAX or NumPy int32 arrays must
  # index the same way in jnp as in NumPy.
  start = jnp.array(2, dtype=np.int32)
  stop = np.array(11, dtype=np.int32)
  step = jnp.array(1, dtype=np.int32)
  indexer = slice(start, stop, step)
  shape = (10,)
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(shape, np.int32)]

  def np_fun(x):
    return np.asarray(x)[indexer]

  jnp_variants = (
      lambda x: jnp.asarray(x)[indexer],
      # Tests x.at[...].get(...) as well.
      lambda x: jnp.asarray(x).at[indexer].get(),
  )
  for jnp_fun in jnp_variants:
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
    self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
  funcname=["negative", "sin", "cos", "square", "sqrt", "log", "exp"],
)
def testIndexApply(self, funcname, size=10, dtype='float32'):
  # x.at[idx].apply(jnp.<funcname>) is compared with NumPy's in-place
  # np.<funcname>.at(x, idx) applied to a copy.
  rng = jtu.rand_default(self.rng())
  idx_rng = jtu.rand_int(self.rng(), -size, size)
  np_func = getattr(np, funcname)
  jnp_func = getattr(jnp, funcname)

  @jtu.ignore_warning(category=RuntimeWarning)
  def np_op(x, idx):
    updated = x.copy()
    np_func.at(updated, idx)
    return updated

  def jnp_op(x, idx):
    return jnp.asarray(x).at[idx].apply(jnp_func)

  # log and exp get an explicit tolerance on TPU; everything else uses the
  # checker's default.
  if jtu.test_device_matches(["tpu"]) and funcname in ("log", "exp"):
    tol = 5e-5
  else:
    tol = None

  # Test with traced integer index
  def traced_args():
    return [rng(size, dtype), idx_rng(size, int)]
  self._CheckAgainstNumpy(np_op, jnp_op, traced_args, atol=tol)
  self._CompileAndCheck(jnp_op, traced_args)

  # Test with slice index
  static_idx = slice(1, 5)
  np_op_idx = partial(np_op, idx=static_idx)
  jnp_op_idx = partial(jnp_op, idx=static_idx)
  def array_only_args():
    return [rng(size, dtype)]
  self._CheckAgainstNumpy(np_op_idx, jnp_op_idx, array_only_args, atol=tol,
                          rtol=tol)
  self._CompileAndCheck(jnp_op_idx, array_only_args)
def testIndexApplyBatchingBug(self):
  # https://github.com/jax-ml/jax/issues/16655
  # vmap of .at[i].apply(...) must match applying it row by row.
  def decrement_at(a, i):
    return a.at[i].apply(lambda x: x - 1)

  arr = jnp.array([[1, 2, 3, 4, 5, 6]])
  ind = jnp.array([3])
  expected = jnp.array([decrement_at(a, i) for a, i in zip(arr, ind)])
  out = jax.vmap(decrement_at)(arr, ind)
  self.assertArraysEqual(out, expected)
def testIndexUpdateScalarBug(self):
  # https://github.com/jax-ml/jax/issues/14923
  # a[0] is 0 and cos(0) == 1, so applying cos at index 0 equals setting 1.
  a = jnp.arange(10.)
  expected = a.at[0].set(1)
  actual = a.at[0].apply(jnp.cos)
  self.assertArraysEqual(actual, expected)
@jtu.sample_product(
  [dict(name=name, shape=shape, indexer=indexer, mode=mode)
   for mode in MODES
   for name, index_specs in (
     STATIC_INDEXING_TESTS if mode == "promise_in_bounds" else
     STATIC_INDEXING_TESTS + STATIC_INDEXING_OUT_OF_BOUNDS_TESTS)
   for shape, indexer, _ in index_specs
  ],
  dtype=float_dtypes,
)
def testStaticIndexingGrads(self, name, shape, dtype, indexer, mode):
  # Gradient check of x.at[indexer].get(mode=...) ** 2. Out-of-bounds cases
  # are excluded for "promise_in_bounds" by the parameter list above.
  rng = jtu.rand_default(self.rng())
  tol = 1e-2 if jnp.finfo(dtype).bits == 32 else None
  arg = rng(shape, dtype)
  # Use an arbitrary finite fill_value, since NaNs won't work in a numerical
  # gradient test.
  fun = lambda x: jnp.asarray(x).at[indexer].get(mode=mode, fill_value=7)**2
  check_grads(fun, (arg,), 2, tol, tol, tol)
def _ReplaceSlicesWithTuples(self, idx):
  """Helper method to replace slices with tuples for dynamic indexing args.

  Returns a pair ``(unpacked, pack)``. ``unpacked`` is ``idx`` with every
  slice replaced by a ``(start, stop, step)`` tuple in which ``None`` entries
  are replaced by 0. ``pack`` rebuilds an indexer of the original structure
  from a value shaped like ``unpacked``, putting ``None`` back in the
  positions that were originally ``None``. Non-empty tuples and lists are
  processed element-wise; anything else is returned unchanged together with
  an identity function.
  """
  if isinstance(idx, slice):
    triple = idx.start, idx.stop, idx.step
    isnone = [i for i, elt in enumerate(triple) if elt is None]
    unpacked = util.subvals(triple, zip(isnone, itertools.repeat(0)))

    def pack_slice(vals):
      # Positions that were None in the original slice become None again.
      return slice(*util.subvals(vals, zip(isnone, itertools.repeat(None))))
    return unpacked, pack_slice

  if isinstance(idx, (tuple, list)) and idx:
    container = type(idx)
    elts, packs = zip(*map(self._ReplaceSlicesWithTuples, idx))

    def pack_sequence(vals):
      return container((pack(v) for pack, v in zip(packs, vals)))
    return elts, pack_sequence

  return idx, lambda x: x
@jtu.sample_product(
  [dict(name=name, shape=shape, indexer=indexer)
   for name, index_specs in [
     ("OneSliceIndex",
      [IndexSpec(shape=(5,), indexer=slice(1, 3)),
       IndexSpec(shape=(5, 4), indexer=slice(1, 3))]),
     ("TwoSliceIndices",
      [IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))),
       IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2)))]),
     ("NonUnitStrides", [
       IndexSpec(shape=(3,), indexer=slice(None, None, -1)),
       IndexSpec(shape=(3, 3), indexer=slice(0, 3, -2)),
       IndexSpec(shape=(3, 4, 5), indexer=slice(0, 4, 2))
     ]),
     ("OnlyStartOrStopDynamic", [
       IndexSpec(shape=(5, 4), indexer=(slice(None, 3), slice(0, 2))),
       IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None)))
     ]),
   ]
   for shape, indexer, _ in index_specs
  ],
  dtype=all_dtypes,
)
def testDynamicIndexingWithSlicesErrors(self, name, shape, dtype, indexer):
  # The slice components are passed as arguments of a jitted function, and
  # indexing with the slice rebuilt from them is expected to raise IndexError.
  rng = jtu.rand_default(self.rng())
  unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

  @jax.jit
  def fun(x, unpacked_indexer):
    indexer = pack_indexer(unpacked_indexer)
    return x[indexer]

  args_maker = lambda: [rng(shape, dtype), unpacked_indexer]
  with self.assertRaises(IndexError):
    fun(*args_maker())
@jtu.sample_product(
  [dict(name=name, shape=shape, indexer=indexer)
   for name, index_specs in [
     ("OneIntIndex",
      [IndexSpec(shape=(3,), indexer=1),
       IndexSpec(shape=(3, 3), indexer=0),
       IndexSpec(shape=(3, 4, 5), indexer=2),
       IndexSpec(shape=(3,), indexer=-1),
       IndexSpec(shape=(3,), indexer=-2)]),
     ("TwoIntIndices",
      [IndexSpec(shape=(3, 3), indexer=(2, 1)),
       IndexSpec(shape=(3, 4, 5), indexer=(1, 2)),
       IndexSpec(shape=(3, 4, 5), indexer=(-1, 2))]),
     ("ThreeIntIndices",
      [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]),
   ]
   for shape, indexer, _ in index_specs
  ],
  dtype=all_dtypes,
)
def testDynamicIndexingWithIntegers(self, name, shape, dtype, indexer):
  # The integer indices are passed to the functions as arguments (rather
  # than closed over) and the results are compared against NumPy.
  rng = jtu.rand_default(self.rng())
  unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

  def np_fun(x, unpacked_indexer):
    indexer = pack_indexer(unpacked_indexer)
    return np.asarray(x)[indexer]

  def jnp_fun(x, unpacked_indexer):
    indexer = pack_indexer(unpacked_indexer)
    return jnp.array(x)[indexer]

  args_maker = lambda: [rng(shape, dtype), unpacked_indexer]
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
2018-11-17 18:03:33 -08:00
@jtu.sample_product(
  [dict(name=name, shape=shape, indexer=indexer)
   for name, index_specs in [
     ("OneIntIndex",
      [IndexSpec(shape=(3,), indexer=1),
       IndexSpec(shape=(3, 3), indexer=0),
       IndexSpec(shape=(3, 4, 5), indexer=2),
       IndexSpec(shape=(3,), indexer=-1),
       IndexSpec(shape=(3,), indexer=-2),
      ]),
     ("TwoIntIndices",
      [IndexSpec(shape=(3, 3), indexer=(2, 1)),
       IndexSpec(shape=(3, 4, 5), indexer=(1, 2)),
       IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)),
      ]),
     ("ThreeIntIndices",
      [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]),
   ]
   for shape, indexer, _ in index_specs
  ],
  dtype=float_dtypes,
)
def testDynamicIndexingWithIntegersGrads(self, name, shape, dtype, indexer):
  # Gradient check, with respect to the array only, of a jitted function
  # that receives the integer indices as an argument.
  rng = jtu.rand_default(self.rng())
  tol = 1e-2 if jnp.finfo(dtype).bits == 32 else None
  unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

  @jax.jit
  def fun(unpacked_indexer, x):
    indexer = pack_indexer(unpacked_indexer)
    return x[indexer]

  arr = rng(shape, dtype)
  check_grads(partial(fun, unpacked_indexer), (arr,), 2, tol, tol, tol)
@jtu.sample_product(
  [dict(name=name, shape=shape, indexer=indexer)
   for name, index_specs in ADVANCED_INDEXING_TESTS
   for shape, indexer, _ in index_specs
  ],
  dtype=all_dtypes,
)
def testAdvancedIntegerIndexing(self, name, shape, dtype, indexer):
  # The indexer is passed as an argument alongside the array and the jnp
  # result is compared against NumPy.
  rng = jtu.rand_default(self.rng())
  args_maker = lambda: [rng(shape, dtype), indexer]
  np_fun = lambda x, idx: np.asarray(x)[idx]
  jnp_fun = lambda x, idx: jnp.asarray(x)[idx]
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
2018-11-17 18:03:33 -08:00
@jtu.sample_product(dtype=jtu.dtypes.unsigned + jtu.dtypes.integer)
def testIndicesNormalizationByType(self, dtype):
  # Inspects the primitives in the jaxpr of x[idx] for an index array of the
  # given integer dtype.
  x = jnp.arange(10)
  jaxpr = jax.make_jaxpr(x.__getitem__)(jnp.arange(3, dtype=dtype))
  primitives = [eqn.primitive for eqn in jaxpr.eqns]
  gather_tail = [lax.broadcast_in_dim_p, lax.gather_p]
  if np.issubdtype(dtype, np.unsignedinteger):
    # Unsigned integers should not require lt, add, and select.
    self.assertEqual(primitives, [lax.convert_element_type_p] + gather_tail)
  else:
    # May or may not contain convert_element_type.
    self.assertIn(len(primitives), [5, 6])
    self.assertEqual(primitives[:3], [lax.lt_p, lax.add_p, lax.select_n_p])
    self.assertEqual(primitives[-2:], gather_tail)
@jtu.sample_product(
  [dict(name=name, shape=shape, indexer=indexer)
   for name, index_specs in [
     ("One1DIntArrayIndex",
      [IndexSpec(shape=(3,), indexer=np.array([0, 1])),
       IndexSpec(shape=(3, 3), indexer=np.array([1, 2, 1])),
       IndexSpec(shape=(3, 4, 5), indexer=np.array([0, 2, 0, 1])),
       IndexSpec(shape=(3,), indexer=np.array([-1, 1])),
       IndexSpec(shape=(3,), indexer=np.array([-2, -1])),
      ]),
     ("One2DIntArrayIndex",
      [IndexSpec(shape=(3,), indexer=np.array([[0, 0]])),
       IndexSpec(shape=(3, 3), indexer=np.array([[1, 2, 1],
                                                 [0, 1, -1]])),
       IndexSpec(shape=(3, 4, 5), indexer=np.array([[0, 2, 0, 1],
                                                    [-1, -2, 1, 0]])),
      ]),
     ("Two1DIntArrayIndicesNoBroadcasting",
      [IndexSpec(shape=(3, 3), indexer=(np.array([0, 1]),
                                        np.array([1, 2]))),
       IndexSpec(shape=(3, 4, 5), indexer=(np.array([0, 2, 0, 1]),
                                           np.array([-1, 0, -1, 2]))),
      ]),
     ("Two1DIntArrayIndicesWithBroadcasting",
      [IndexSpec(shape=(3, 3), indexer=(np.array([[0, 1]]),
                                        np.array([1, 2]))),
       IndexSpec(shape=(3, 4, 5), indexer=(np.array([[0, 2, 0, 1]]),
                                           np.array([-1, 0, -1, 2]))),
      ]),
     ("TupleOfPythonIntsAndIntArrays",
      [IndexSpec(shape=(3, 4, 5), indexer=(0, np.array([0, 1]))),
       IndexSpec(shape=(3, 4, 5), indexer=(0, 1,
                                           np.array([[2, 3, 0, 3]]))),
      ]),
     ("TupleOfListsOfPythonIntsAndIntArrays",
      [IndexSpec(shape=(3, 4, 5), indexer=([0, 1], np.array([0]))),
       IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]],
                                           np.array([[2, 3, 0, 3]]))),
      ]),
   ]
   for shape, indexer, _ in index_specs
  ],
  dtype=float_dtypes,
)
def testAdvancedIntegerIndexingGrads(self, name, shape, dtype, indexer):
  # Gradient check of x[indexer] for integer-array indexers, with the
  # indexer closed over rather than passed as an argument.
  rng = jtu.rand_default(self.rng())
  tol = 1e-2 if jnp.finfo(dtype).bits == 32 else None
  arg = rng(shape, dtype)
  fun = lambda x: jnp.asarray(x)[indexer]
  check_grads(fun, (arg,), 2, tol, tol, eps=1.)
2018-11-17 18:03:33 -08:00
@jtu.sample_product(
  [dict(name=name, shape=shape, indexer=indexer)
   for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS
   for shape, indexer, _ in index_specs
  ],
  dtype=all_dtypes,
)
def testMixedAdvancedIntegerIndexing(self, name, shape, dtype, indexer):
  rng = jtu.rand_default(self.rng())
  # Only the ndarray entries of the indexer are passed as arguments. Every
  # other entry (slice, None, Ellipsis, ...) is replaced by an empty-tuple
  # placeholder and substituted back inside the function under test.
  indexer_with_dummies = [e if isinstance(e, np.ndarray) else ()
                          for e in indexer]
  substitutes = [(i, e) for i, e in enumerate(indexer)
                 if not isinstance(e, np.ndarray)]
  args_maker = lambda: [rng(shape, dtype), indexer_with_dummies]

  def rebuild_index(indexer_with_dummies):
    # Restore the non-array entries and the original container type.
    return type(indexer)(util.subvals(indexer_with_dummies, substitutes))

  def jnp_fun(x, indexer_with_dummies):
    return jnp.asarray(x)[rebuild_index(indexer_with_dummies)]

  def np_fun(x, indexer_with_dummies):
    return np.asarray(x)[rebuild_index(indexer_with_dummies)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
2018-11-17 18:03:33 -08:00
def testAdvancedIndexingManually(self):
  # Each indexing expression must give the same result whether it is
  # evaluated directly or through jax.jit.
  x = self.rng().randn(3, 4, 5)
  index_array = np.array([0, 2, -1, 0])

  ops = (
      lambda x, index_array: x[..., index_array, :],
      lambda x, index_array: x[..., index_array, :, index_array, None],
      lambda x, index_array: x[index_array, ..., index_array[:, None], None],
  )
  for op in ops:
    cop = jax.jit(op)
    a1 = op(x, index_array)
    a2 = cop(x, index_array)
    self.assertAllClose(a1, a2)
2018-11-17 18:03:33 -08:00
def testUnpacking(self):
  # Unpacking a length-3 array into three names must work under jax.jit and
  # give the same sum as the un-jitted function.
  def foo(x):
    a, b, c = x
    return a + b + c

  cfoo = jax.jit(foo)
  a1 = foo(np.arange(3))
  a2 = cfoo(np.arange(3))
  self.assertAllClose(a1, a2)
2018-11-17 18:03:33 -08:00
def testBooleanIndexingArray1D(self):
  # A 1-D boolean mask applied to a device array matches NumPy.
  mask = np.array([True, True, False])
  on_device = jax.device_put(np.arange(3))
  self.assertAllClose(on_device[mask], np.arange(3)[mask],
                      check_dtypes=False)
def testBooleanIndexingList1D(self):
  # A plain Python list of booleans is rejected with the ARRAY_MSG error.
  mask = [True, True, False]
  on_device = jax.device_put(np.arange(3))
  self.assertRaisesRegex(TypeError, ARRAY_MSG, lambda: on_device[mask])
def testBooleanIndexingArray2DBroadcast(self):
  # A 1-D boolean mask over the leading axis of a 2-D array matches NumPy.
  mask = np.array([True, True, False, True])
  data = np.arange(8).reshape(4, 2)
  self.assertAllClose(jax.device_put(data)[mask], data[mask],
                      check_dtypes=False)
def testBooleanIndexingList2DBroadcast(self):
  # A plain Python list of booleans is rejected for a 2-D array as well.
  mask = [True, True, False, True]
  data = np.arange(8).reshape(4, 2)
  self.assertRaisesRegex(TypeError, ARRAY_MSG,
                         lambda: jax.device_put(data)[mask])
def testBooleanIndexingArray2D(self):
  # A boolean mask with the same shape as the array matches NumPy.
  data = np.arange(8).reshape(4, 2)
  mask = np.array([
      [True, False],
      [False, True],
      [False, False],
      [True, True],
  ])
  self.assertAllClose(jax.device_put(data)[mask], data[mask],
                      check_dtypes=False)
def testBoolean1DIndexingWithEllipsis(self):
  # Regression test for https://github.com/jax-ml/jax/issues/8412
  data = np.arange(24).reshape(4, 3, 2)
  last_axis_mask = np.array([True, False])
  idx = (..., last_axis_mask)
  self.assertAllClose(jnp.array(data)[idx], data[idx], check_dtypes=False)
def testBoolean1DIndexingWithEllipsis2(self):
  # Regression test for https://github.com/jax-ml/jax/issues/9050
  data = np.arange(3)
  mask = np.array([True, False, True])
  idx = (..., mask)
  self.assertAllClose(jnp.array(data)[idx], data[idx], check_dtypes=False)
def testBoolean1DIndexingWithEllipsis3(self):
  # Integer index, then Ellipsis, then a 1-D boolean mask on the last axis.
  data = np.arange(6).reshape(2, 3)
  mask = np.array([True, False, True])
  idx = (0, ..., mask)
  self.assertAllClose(jnp.array(data)[idx], data[idx], check_dtypes=False)
def testBoolean2DIndexingWithEllipsis(self):
  # Ellipsis followed by a (3, 2) boolean mask on a (4, 3, 2) array.
  host_values = np.arange(24).reshape(4, 3, 2)
  mask = np.array([[True, False], [True, False], [False, False]])
  indexer = (Ellipsis, mask)
  self.assertAllClose(jnp.array(host_values)[indexer], host_values[indexer],
                      check_dtypes=False)
def testBoolean1DIndexingWithTrailingEllipsis(self):
  # A 1D boolean mask followed by a trailing Ellipsis must match NumPy.
  host_values = np.arange(24).reshape(4, 3, 2)
  indexer = (np.array([True, False, True, False]), Ellipsis)
  self.assertAllClose(jnp.array(host_values)[indexer], host_values[indexer],
                      check_dtypes=False)
def testBooleanIndexingDynamicShapeError(self):
  # Passing the boolean mask as an argument of a jitted function must raise
  # IndexError.
  values = np.zeros(3)
  mask = np.array([True, True, False])
  with self.assertRaises(IndexError):
    jax.jit(lambda arr, idx: arr[idx])(values, mask)
def testIssue187(self):
  # Indexing with a pair of integer lists must work eagerly and match NumPy
  # when run under jit.
  rows, cols = [0, 2, 4], [0, 2, 4]
  jnp.ones((5, 5))[rows, cols]  # doesn't crash

  host_values = np.arange(25).reshape((5, 5))
  jitted = jax.jit(lambda arr: arr[[0, 2, 4], [0, 2, 4]])(host_values)
  self.assertAllClose(jitted, host_values[rows, cols], check_dtypes=False)
def testJVPOfGradOfIndexing(self):
  # Should return a value, even though we didn't pass a symbolic zero as the
  # index tangent.
  values = jnp.ones((3, 4), jnp.float32)
  indices = jnp.ones((3,), jnp.int32)

  def gathered_sum(arr, idx):
    return jnp.sum(arr[idx])

  # The integer index gets an explicit float0 zero tangent.
  index_tangent = np.zeros(indices.shape, dtypes.float0)
  primals_out, tangents_out = jax.jvp(
      jax.grad(gathered_sum), (values, indices), (values, index_tangent))

  # All three indices are 1, so row 1 gets gradient 3 and the others get 0.
  per_row = np.array([0, 3, 0], dtype=np.float32)
  expected_grad = np.broadcast_to(per_row[:, None], (3, 4))
  self.assertAllClose(expected_grad, primals_out)
  self.assertAllClose(np.zeros_like(values), tangents_out)
def testSimpleIndexingUsesSlice(self):
  # Pins the number of jaxpr equations and the trailing primitives produced
  # by several basic indexing expressions.
  def eqns_for(fn, *args):
    return jax.make_jaxpr(fn)(*args).jaxpr.eqns

  def check(eqns, num_eqns, *trailing_primitives):
    # The last len(trailing_primitives) equations must use these primitives.
    self.assertEqual(len(eqns), num_eqns)
    first = -len(trailing_primitives)
    for pos, primitive in enumerate(trailing_primitives, start=first):
      self.assertEqual(eqns[pos].primitive, primitive)

  check(eqns_for(lambda x: x[:2, :2], jnp.ones((3, 4))), 1, lax.slice_p)

  cube = jnp.ones((3, 4, 5))
  for fn in (lambda x: x[0, :2, 1], lambda x: x[0, 0], lambda x: x[:, 1]):
    check(eqns_for(fn, cube), 2, lax.slice_p, lax.squeeze_p)

  # Indexing with `Ellipsis` is not lowered to `gather`.
  check(eqns_for(lambda x: x[..., 0], cube), 2, lax.slice_p, lax.squeeze_p)

  # Simple reverses lower to lax.rev_p
  check(eqns_for(lambda x: x[:, ::-1], jnp.ones((3, 4))), 1, lax.rev_p)

  # Non-static indices produce a dynamic slice
  check(eqns_for(lambda x, i: x[i], jnp.ones((4,)), 2), 6,
        lax.dynamic_slice_p, lax.squeeze_p)
def testTrivialGatherIsntGenerated(self):
  # https://github.com/jax-ml/jax/issues/1621
  def trace(fn):
    return jax.make_jaxpr(fn)(np.arange(4))

  with_new_axis = trace(lambda x: x[:, None])
  self.assertEqual(len(with_new_axis.jaxpr.eqns), 1)
  self.assertNotIn('gather', str(with_new_axis))

  # Slices covering the whole length-4 array must trace to zero equations.
  for whole_array_slice in (lambda x: x[0:6:1], lambda x: x[:4]):
    self.assertEqual(len(trace(whole_array_slice).jaxpr.eqns), 0)

  reversed_eqns = trace(lambda x: x[::-1]).jaxpr.eqns
  self.assertEqual(len(reversed_eqns), 1)
  self.assertEqual(reversed_eqns[0].primitive, lax.rev_p)
def testOOBEmptySlice(self):
  # Each slice below selects no elements; the result must be an empty
  # float32 array of the matching shape.
  vec = jnp.arange(4, dtype='float32')
  for empty_slice in (np.s_[1:0], np.s_[-2:-10], np.s_[5:10]):
    self.assertArraysEqual(vec[empty_slice], jnp.empty(0, dtype='float32'))

  mat = jnp.arange(6, dtype='float32').reshape(2, 3)
  no_rows = jnp.empty((0, 3), dtype='float32')
  no_cols = jnp.empty((2, 0), dtype='float32')
  self.assertArraysEqual(mat[-1:-4], no_rows)
  self.assertArraysEqual(mat[:, 3:2], no_cols)
def testIndexingEmptyDimension(self):
  # Issue 2671: XLA error when indexing into dimension of size 0
  empty = jnp.ones((2, 0))

  # The following work, even on axis 1 of size 0
  with jax.numpy_rank_promotion('allow'):
    _ = empty[0, :] + empty[0, None] + empty[0, 1:] + empty[0, 1:3:2]

  numpy_error = "index .* is out of bounds for axis .* with size 0"
  jax_error = "index is out of bounds for axis .* with size 0"

  with self.assertRaisesRegex(IndexError, numpy_error):
    _ = np.ones((2, 0))[0, 0]  # The numpy error

  with self.assertRaisesRegex(IndexError, jax_error):
    _ = empty[0, 0]  # JAX indexing

  with self.assertRaisesRegex(IndexError, jax_error):
    jax.jit(lambda i: empty[0, i])(0)  # JAX indexing under jit
def testBooleanIndexingWithEmptyResult(self):
  # based on a TensorFlow Probability test that started failing after #1622
  selected = jnp.array([-1])[jnp.array([False])]  # doesn't crash
  reference = np.array([-1])[np.array([False])]
  self.assertAllClose(selected, reference, check_dtypes=False)
def testBooleanIndexingShapeMismatch(self):
  # Regression test for https://github.com/jax-ml/jax/issues/7329
  # A length-2 mask cannot index a length-4 array.
  values = jnp.arange(4)
  short_mask = jnp.array([True, False])
  with self.assertRaisesRegex(IndexError, "boolean index did not match shape.*"):
    values[short_mask]
def testBooleanIndexingWithNone(self):
  # Regression test for https://github.com/jax-ml/jax/issues/18542
  values = jnp.arange(6).reshape(2, 3)
  row_mask = jnp.array([True, False])
  result = values[None, row_mask]
  self.assertAllClose(result, jnp.arange(3).reshape(1, 1, 3))
def testBooleanIndexingWithNoneAndEllipsis(self):
  # Regression test for https://github.com/jax-ml/jax/issues/18542
  values = jnp.arange(6).reshape(2, 3)
  column_mask = jnp.array([True, False, False])
  indexer = (None, Ellipsis, column_mask)
  self.assertAllClose(values[indexer],
                      jnp.array([0, 3]).reshape(1, 2, 1))
def testBooleanIndexingWithEllipsisAndNone(self):
  # Regression test for https://github.com/jax-ml/jax/issues/18542
  values = jnp.arange(6).reshape(2, 3)
  column_mask = jnp.array([True, False, False])
  indexer = (Ellipsis, None, column_mask)
  self.assertAllClose(values[indexer],
                      jnp.array([0, 3]).reshape(2, 1, 1))
def testNontrivialBooleanIndexing(self):
  # Test nontrivial corner case in boolean indexing shape validation
  rng = jtu.rand_default(self.rng())
  # A (2, 3) mask plus a (6,) mask, applied to a (2, 3, 6) array.
  leading_mask = rng((2, 3), np.bool_)
  trailing_mask = rng((6,), np.bool_)
  index = (leading_mask, trailing_mask)

  def args_maker():
    return [rng((2, 3, 6), np.int32)]

  def np_fun(x):
    return np.asarray(x)[index]

  def jnp_fun(x):
    return jnp.asarray(x)[index]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.parameters(
  [(3,), (0,)],
  [(3, 4), (0,)],
  [(3, 4), (0, 4)],
  [(3, 4), (3, 0)],
  [(3, 4, 5), (3, 0)],
)
def testEmptyBooleanIndexing(self, x_shape, m_shape):
  # Regression test for https://github.com/jax-ml/jax/issues/22886
  rng = jtu.rand_default(self.rng())

  def args_maker():
    # Every m_shape has a zero-sized dimension, so the mask has no elements.
    return [rng(x_shape, np.int32), np.empty(m_shape, dtype=bool)]

  def np_fun(x, m):
    return np.asarray(x)[np.asarray(m)]

  def jnp_fun(x, m):
    return jnp.asarray(x)[jnp.asarray(m)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@jtu.sample_product(
  shape=[(2, 3, 4, 5)],
  idx=[
    np.index_exp[True],
    np.index_exp[False],
    np.index_exp[..., True],
    np.index_exp[..., False],
    np.index_exp[0, :2, True],
    np.index_exp[0, :2, False],
    np.index_exp[:2, 0, True],
    np.index_exp[:2, 0, False],
    np.index_exp[:2, np.array([0, 2]), True],
    np.index_exp[np.array([1, 0]), :, True],
    np.index_exp[True, :, True, :, np.array(True)],
  ]
)
def testScalarBooleanIndexing(self, shape, idx):
  # Indexers containing scalar booleans must read the same values as NumPy.
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(shape, np.int32)]

  def np_fun(x):
    return np.asarray(x)[idx]

  def jnp_fun(x):
    return jnp.asarray(x)[idx]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@jtu.sample_product(
  shape=[(2, 3, 4, 5)],
  update_ndim=[0, 1, 2],
  idx=[
    np.index_exp[True],
    np.index_exp[False],
    np.index_exp[..., True],
    np.index_exp[..., False],
    np.index_exp[0, :2, True],
    np.index_exp[0, :2, False],
    np.index_exp[:2, 0, True],
    np.index_exp[:2, 0, False],
    np.index_exp[:2, np.array([0, 2]), True],
    np.index_exp[np.array([1, 0]), :, True],
    np.index_exp[True, :, True, :, np.array(True)],
  ]
)
def testScalarBoolUpdate(self, shape, idx, update_ndim):
  # `.at[idx].set(update)` with scalar-boolean indexers must match NumPy's
  # `x[idx] = update`. The update takes the trailing `update_ndim` dimensions
  # of the indexed region.
  indexed_shape = np.zeros(shape)[idx].shape
  # Slice from an explicit start offset: `indexed_shape[-update_ndim:]` would
  # return the whole shape for update_ndim == 0 (because -0 == 0), so the
  # 0-d update case would never be exercised.
  num_dropped = max(len(indexed_shape) - update_ndim, 0)
  update_shape = indexed_shape[num_dropped:]

  rng = jtu.rand_default(self.rng())
  args_maker = lambda: [rng(shape, np.int32), rng(update_shape, np.int32)]

  def np_fun(x, update):
    # Copy first so the in-place NumPy assignment cannot touch the input.
    x = np.array(x, copy=True)
    x[idx] = update
    return x

  jnp_fun = lambda x, update: jnp.asarray(x).at[idx].set(update)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
def testFloatIndexingError(self):
  # Float and complex indexers must raise TypeError, both for reads and for
  # `.at[...]` updates, eagerly and under jit.
  BAD_INDEX_TYPE_ERROR = "Indexer must have integer or boolean type, got indexer with type"
  with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
    jnp.zeros(2)[0.]
  # This tuple-index check used to appear twice verbatim; once is enough.
  with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
    jnp.zeros((2, 2))[(0, 0.)]
  with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
    jax.jit(lambda idx: jnp.zeros((2, 2))[idx])((0, 0.))
  with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
    jnp.zeros(2).at[0.].add(1.)
  with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
    jnp.zeros(2).at[0.].set(1.)
  with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
    jnp.zeros((2, 2))[jnp.arange(2), 1.0]
  with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
    jnp.zeros((2, 2))[jnp.arange(2), 1 + 1j]
def testStrIndexingError(self):
  # A string index must raise TypeError, alone or inside a tuple index.
  expected_error = "JAX does not support string indexing"
  cases = [
      (jnp.zeros(2), 'abc'),
      (jnp.zeros((2, 3)), (slice(None), 'abc')),
  ]
  for array, index in cases:
    with self.assertRaisesRegex(TypeError, expected_error):
      array[index]
def testIndexOutOfBounds(self):  # https://github.com/jax-ml/jax/issues/2245
  values = jnp.arange(5, dtype=jnp.int32) + 1
  self.assertAllClose(values, values[:10])

  indices = jnp.array([-10, -6, -5, -4, 0, 3, 4, 5, 6, 100])

  clipped = values.at[indices].get(mode="clip")
  self.assertArraysEqual(
      clipped, jnp.array([1, 1, 1, 2, 1, 4, 5, 5, 5, 5], jnp.int32))

  def expected_with_fill(fill, dtype):
    # Positions 0, 1 and 7-9 of `indices` are the ones that take `fill`.
    return jnp.array([fill, fill, 1, 2, 1, 4, 5, fill, fill, fill], dtype)

  # Fill values expected when none is given: NaN, int32 min and uint32 max.
  nan = np.nan
  self.assertArraysEqual(
      values.astype(jnp.float32).at[indices].get(mode="fill"),
      expected_with_fill(nan, jnp.float32))

  imin = np.iinfo(np.int32).min
  self.assertArraysEqual(
      values.at[indices].get(mode="fill"),
      expected_with_fill(imin, jnp.int32))

  umax = np.iinfo(np.uint32).max
  self.assertArraysEqual(
      values.astype(np.uint32).at[indices].get(mode="fill"),
      expected_with_fill(umax, jnp.uint32))

  # An explicit fill_value is used as given.
  self.assertArraysEqual(
      values.at[indices].get(mode="fill", fill_value=7),
      expected_with_fill(7, jnp.int32))
def testIndexingWeakTypes(self):
  # set/add/mul with a Python float on a weakly-typed array must keep both
  # the dtype and the weak type.
  x = lax_internal._convert_element_type(jnp.arange(5), float, weak_type=True)
  for method_name in ("set", "add", "mul"):
    result = getattr(x.at[0], method_name)(1.0)
    self.assertEqual(result.dtype, x.dtype)
    self.assertTrue(dtypes.is_weakly_typed(result))
def testIndexingTypePromotion(self):
  # `x.at[0].set(y)` must return an array with x's dtype. Depending on the
  # promotion mode, some strongly-typed (x, y) dtype pairs warn or raise.
  def _check(x_type, y_type):
    target = jnp.arange(5, dtype=x_type)
    updated = target.at[0].set(y_type(0))
    self.assertEqual(target.dtype, updated.dtype)

  @jtu.ignore_warning(category=NumpyComplexWarning,
                      message="Casting complex values to real")
  def _check_warns(x_type, y_type, msg):
    with self.assertWarnsRegex(FutureWarning, msg):
      _check(x_type, y_type)

  def _check_raises(x_type, y_type, msg):
    with self.assertRaisesRegex(ValueError, msg):
      _check(x_type, y_type)

  # Matching dtypes are always OK
  for same_type in (jnp.int32, jnp.float32, jnp.complex64):
    _check(same_type, same_type)

  # Weakly-typed y values promote.
  weak_pairs = [
      (jnp.int32, int),
      (jnp.float32, int),
      (jnp.float32, float),
      (jnp.complex64, int),
      (jnp.complex64, float),
      (jnp.complex64, complex),
  ]
  for x_type, y_type in weak_pairs:
    _check(x_type, y_type)

  # Strong (x, y) pairs that pass silently in standard mode.
  strong_pairs_ok = [
      (jnp.int32, jnp.int16),
      (jnp.float32, jnp.float16),
      (jnp.float32, jnp.int32),
      (jnp.complex64, jnp.int32),
      (jnp.complex64, jnp.float32),
  ]
  # Strong (x, y) pairs that emit a FutureWarning in standard mode.
  strong_pairs_warn = [
      (jnp.int16, jnp.int32),
      (jnp.int32, jnp.float32),
      (jnp.int32, jnp.complex64),
      (jnp.float16, jnp.float32),
      (jnp.float32, jnp.complex64),
  ]

  # in standard promotion mode, strong types can promote.
  msg = "scatter inputs have incompatible types"
  with jax.numpy_dtype_promotion('standard'):
    for x_type, y_type in strong_pairs_ok:
      _check(x_type, y_type)

    # TODO(jakevdp): make these _check_raises
    for x_type, y_type in strong_pairs_warn:
      _check_warns(x_type, y_type, msg)

  # in strict promotion mode, strong types do not promote.
  msg = "Input dtypes .* have no available implicit dtype promotion path"
  with jax.numpy_dtype_promotion('strict'):
    for x_type, y_type in strong_pairs_ok + strong_pairs_warn:
      _check_raises(x_type, y_type, msg)
def testWrongNumberOfIndices(self):
  # Supplying more regular indices than the array has dimensions must raise
  # IndexError with a message naming both counts.
  zero_dim_error = (
      "Too many indices: 0-dimensional array indexed with 1 regular index.")
  one_dim_error = (
      "Too many indices: 1-dimensional array indexed with 2 regular indices.")

  with self.assertRaisesRegex(IndexError, zero_dim_error):
    jnp.array(1)[0]

  with self.assertRaisesRegex(IndexError, one_dim_error):
    jnp.zeros(3)[:, 5]
def _broadcastable_shapes(shape):
"""Returns all shapes that broadcast to `shape`."""
def f(rshape):
yield []
if rshape:
for s in f(rshape[1:]):
yield rshape[0:1] + s
if rshape[0] != 1:
for s in f(rshape[1:]):
yield [1] + s
for x in f(list(reversed(shape))):
yield list(reversed(x))
# TODO(jakevdp): move this implementation to jax.dtypes & use in scatter?
def _can_cast(from_, to):
with jax.numpy_dtype_promotion('standard'):
return lax.dtype(to) == dtypes.result_type(from_, to)
def _compatible_dtypes(op, dtype, inexact=False):
  """Returns the update dtypes to test against an operand of type `dtype`.

  ADD and SUB only use `dtype` itself. Other ops use every dtype from
  `float_dtypes` (if `inexact`) or `all_dtypes` for which `_can_cast` to
  `dtype` holds.
  """
  if op in (UpdateOps.ADD, UpdateOps.SUB):
    return [dtype]
  candidates = float_dtypes if inexact else all_dtypes
  return [dt for dt in candidates if _can_cast(dt, dtype)]
class UpdateOps(enum.Enum):
  """Indexed-update operations, with NumPy and JAX reference implementations.

  The functions below take the operation as their first argument and are
  called through the class, e.g. `UpdateOps.np_fn(op, indexer, x, y)`.
  """
  UPDATE = 0
  ADD = 1
  SUB = 2
  MUL = 3
  DIV = 4
  POW = 5
  MIN = 6
  MAX = 7

  def np_fn(op, indexer, x, y):
    """Applies `op` to a copy of NumPy array `x` at `indexer` and returns it."""
    out = x.copy()

    # Division and power may emit RuntimeWarnings; silence them.
    @jtu.ignore_warning(category=RuntimeWarning)
    def divide():
      return out[indexer] / y.astype(out.dtype)

    @jtu.ignore_warning(category=RuntimeWarning)
    def power():
      return out[indexer] ** y.astype(out.dtype)

    # Values are thunks so only the selected operation is evaluated.
    compute = {
      UpdateOps.UPDATE: lambda: y,
      UpdateOps.ADD: lambda: out[indexer] + y,
      UpdateOps.SUB: lambda: out[indexer] - y,
      UpdateOps.MUL: lambda: out[indexer] * y,
      UpdateOps.DIV: divide,
      UpdateOps.POW: power,
      UpdateOps.MIN: lambda: np.minimum(out[indexer], y),
      UpdateOps.MAX: lambda: np.maximum(out[indexer], y),
    }
    out[indexer] = compute[op]()
    return out

  def jax_fn(op, indexer, x, y, indices_are_sorted=False,
             unique_indices=False, mode=None):
    """Applies `op` through the matching `jnp.array(x).at[indexer]` method."""
    method_names = {
      UpdateOps.UPDATE: "set",
      UpdateOps.ADD: "add",
      UpdateOps.SUB: "subtract",
      UpdateOps.MUL: "multiply",
      UpdateOps.DIV: "divide",
      UpdateOps.POW: "power",
      UpdateOps.MIN: "min",
      UpdateOps.MAX: "max",
    }
    updater = getattr(jnp.array(x).at[indexer], method_names[op])
    return updater(y, indices_are_sorted=indices_are_sorted,
                   unique_indices=unique_indices, mode=mode)

  def dtypes(op):
    """Returns the operand dtypes to test for `op`."""
    if op == UpdateOps.UPDATE:
      return all_dtypes
    if op in (UpdateOps.DIV, UpdateOps.POW):
      return jtu.dtypes.inexact
    return default_dtypes
def _update_tol(op):
  """Returns the per-dtype tolerance dict used when checking `op` results."""
  if op != UpdateOps.POW:
    return {np.complex128: 1e-14}
  # POW also sets float32/complex64 tolerances, with a looser value on TPU.
  f32_tol = 2e-4 if jtu.test_device_matches(["tpu"]) else 1e-5
  return {np.float32: f32_tol, np.complex64: f32_tol, np.complex128: 1e-14}
class IndexedUpdateTest(jtu.JaxTestCase):
@jtu.sample_product(
  [dict(name=name, shape=shape, indexer=indexer, update_shape=update_shape)
    for name, index_specs in STATIC_INDEXING_TESTS
    for shape, indexer, index_shape in index_specs
    for update_shape in _broadcastable_shapes(index_shape)
  ],
  [dict(op=op, dtype=dtype, update_dtype=update_dtype)
    for op in UpdateOps
    for dtype in UpdateOps.dtypes(op)
    for update_dtype in _compatible_dtypes(op, dtype)
  ],
  mode=MODES,
)
def testStaticIndexing(self, name, shape, dtype, update_shape, update_dtype,
                       indexer, op, mode):
  # Compares each indexed update against NumPy for static indexers, then
  # re-checks the JAX function through _CompileAndCheck.
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(shape, dtype), rng(update_shape, update_dtype)]

  def numpy_update(x, y):
    return UpdateOps.np_fn(op, indexer, x, y)

  def jax_update(x, y):
    return UpdateOps.jax_fn(op, indexer, x, y, mode=mode)

  with jtu.strict_promotion_if_dtypes_match([dtype, update_dtype]):
    self._CheckAgainstNumpy(numpy_update, jax_update, args_maker,
                            tol=_update_tol(op))
    self._CompileAndCheck(jax_update, args_maker)
@jtu.sample_product(
  [dict(name=name, shape=shape, indexer=indexer, update_shape=update_shape)
    for name, index_specs in ADVANCED_INDEXING_TESTS_NO_REPEATS
    for shape, indexer, index_shape in index_specs
    for update_shape in _broadcastable_shapes(index_shape)
  ],
  [dict(op=op, dtype=dtype, update_dtype=update_dtype)
    for op in UpdateOps
    for dtype in UpdateOps.dtypes(op)
    for update_dtype in _compatible_dtypes(op, dtype)
  ],
)
def testAdvancedIndexing(self, name, shape, dtype, update_shape, update_dtype,
                         indexer, op):
  # Compares each indexed update against NumPy for the no-repeat advanced
  # indexers, passing unique_indices=True to JAX.
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(shape, dtype), rng(update_shape, update_dtype)]

  def numpy_update(x, y):
    return UpdateOps.np_fn(op, indexer, x, y)

  def jax_update(x, y):
    return UpdateOps.jax_fn(op, indexer, x, y, unique_indices=True)

  with jtu.strict_promotion_if_dtypes_match([dtype, update_dtype]):
    self._CheckAgainstNumpy(numpy_update, jax_update, args_maker,
                            tol=_update_tol(op))
    self._CompileAndCheck(jax_update, args_maker)
@jtu.sample_product(
  [dict(name=name, shape=shape, indexer=indexer, update_shape=update_shape)
    for name, index_specs in ADVANCED_INDEXING_TESTS_NO_REPEATS_SORTED
    for shape, indexer, index_shape in index_specs
    for update_shape in _broadcastable_shapes(index_shape)
  ],
  [dict(op=op, dtype=dtype, update_dtype=update_dtype)
    for op in UpdateOps
    for dtype in UpdateOps.dtypes(op)
    for update_dtype in _compatible_dtypes(op, dtype)
  ],
)
def testAdvancedIndexingSorted(self, name, shape, dtype, update_shape,
                               update_dtype, indexer, op):
  # Same comparison as testAdvancedIndexing, but on the sorted no-repeat
  # indexers, so both indices_are_sorted and unique_indices are passed.
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(shape, dtype), rng(update_shape, update_dtype)]

  def numpy_update(x, y):
    return UpdateOps.np_fn(op, indexer, x, y)

  def jax_update(x, y):
    return UpdateOps.jax_fn(
        op, indexer, x, y, indices_are_sorted=True, unique_indices=True)

  with jtu.strict_promotion_if_dtypes_match([dtype, update_dtype]):
    self._CheckAgainstNumpy(numpy_update, jax_update, args_maker,
                            check_dtypes=True, tol=_update_tol(op))
    self._CompileAndCheck(jax_update, args_maker, check_dtypes=True)
@jtu.sample_product(
  [dict(name=name, shape=shape, indexer=indexer, update_shape=update_shape)
    for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS
    for shape, indexer, index_shape in index_specs
    for update_shape in _broadcastable_shapes(index_shape)
  ],
  [dict(op=op, dtype=dtype, update_dtype=update_dtype)
    for op in UpdateOps
    for dtype in UpdateOps.dtypes(op)
    for update_dtype in _compatible_dtypes(op, dtype)
  ],
)
def testMixedAdvancedIndexing(self, name, shape, dtype, update_shape,
                              update_dtype, indexer, op):
  # Compares each indexed update against NumPy for the mixed advanced
  # indexers, using the default jax_fn options.
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(shape, dtype), rng(update_shape, update_dtype)]

  def numpy_update(x, y):
    return UpdateOps.np_fn(op, indexer, x, y)

  def jax_update(x, y):
    return UpdateOps.jax_fn(op, indexer, x, y)

  with jtu.strict_promotion_if_dtypes_match([dtype, update_dtype]):
    self._CheckAgainstNumpy(numpy_update, jax_update, args_maker,
                            tol=_update_tol(op))
    self._CompileAndCheck(jax_update, args_maker)
@jtu.sample_product(
  [dict(name=name, mode=mode, shape=shape, indexer=indexer,
        update_shape=update_shape)
    for mode in [None] + MODES
    for name, index_specs in (
      STATIC_INDEXING_TESTS if mode == "promise_in_bounds" else
      STATIC_INDEXING_TESTS + STATIC_INDEXING_OUT_OF_BOUNDS_TESTS)
    for shape, indexer, index_shape in index_specs
    for update_shape in _broadcastable_shapes(index_shape)
  ],
  [dict(op=op, dtype=dtype, update_dtype=update_dtype)
    for op in [UpdateOps.ADD, UpdateOps.SUB, UpdateOps.MUL, UpdateOps.UPDATE]
    for dtype in float_dtypes
    for update_dtype in _compatible_dtypes(op, dtype, inexact=True)
  ],
)
def testStaticIndexingGrads(self, name, shape, dtype, update_shape,
                            update_dtype, indexer, op, mode):
  # Checks gradients up to order 2 of indexed updates with static indexers.
  rng = jtu.rand_default(self.rng())

  def jax_update(x, y):
    return UpdateOps.jax_fn(op, indexer, x, y, mode=mode,
                            unique_indices=True)

  operand = rng(shape, dtype)
  update = rng(update_shape, update_dtype)
  with jtu.strict_promotion_if_dtypes_match([dtype, update_dtype]):
    check_grads(jax_update, (operand, update), 2,
                rtol=1e-3, atol=1e-3, eps=1.)
@parameterized.parameters(itertools.chain.from_iterable(
  jtu.sample_product_testcases(
    [dict(name=name, unique_indices=unique_indices, shape=shape,
          indexer=indexer, update_shape=update_shape)
      for name, index_specs in (
        ADVANCED_INDEXING_TESTS_NO_REPEATS if unique_indices
        else ADVANCED_INDEXING_TESTS)
      for shape, indexer, index_shape in index_specs
      for update_shape in _broadcastable_shapes(index_shape)
    ],
    [dict(op=op, dtype=dtype, update_dtype=update_dtype)
      for op in (
        [UpdateOps.ADD, UpdateOps.SUB, UpdateOps.MUL, UpdateOps.UPDATE]
        if unique_indices
        else [UpdateOps.ADD, UpdateOps.SUB])
      for dtype in float_dtypes
      for update_dtype in _compatible_dtypes(op, dtype, inexact=True)
    ],
  )
  for unique_indices in [False, True]
))
def testAdvancedIndexingGrads(self, name, shape, dtype, update_shape,
                              update_dtype, indexer, op, unique_indices):
  # Checks gradients up to order 2 of indexed updates with advanced indexers.
  # With unique_indices=False only ADD and SUB are parameterized.
  rng = jtu.rand_default(self.rng())

  def jax_update(x, y):
    return UpdateOps.jax_fn(op, indexer, x, y,
                            unique_indices=unique_indices)

  operand = rng(shape, dtype)
  update = rng(update_shape, update_dtype)
  with jtu.strict_promotion_if_dtypes_match([dtype, update_dtype]):
    check_grads(jax_update, (operand, update), 2,
                rtol=1e-3, atol=1e-3, eps=1.)
def testIndexMulGradFailsIfNotUnique(self):
  # Differentiating `.at[...].mul` without unique_indices=True must raise
  # NotImplementedError with an explanatory message.
  y = jnp.ones((10,), jnp.int32)  # ten indices, all equal to 1
  f = lambda x, z: x.at[y].mul(z)

  x = jnp.ones((100,), jnp.float32)
  z = jnp.ones((10,), jnp.float32)

  # assertRaises' `msg=` keyword only customizes the failure message and never
  # inspects the exception text, so assertRaisesRegex is used to check it.
  with self.assertRaisesRegex(
      NotImplementedError,
      "scatter_mul gradients are only implemented if "
      "`unique_indices=True`"):
    jax.jvp(f, (x, z), (x, z))
def testSegmentSumBehavior(self):
  # testAdvancedIndexing compares against NumPy, and as a result doesn't check
  # repeated indices. This test is just a simple manual check, based on
  # https://www.tensorflow.org/api_docs/python/tf/math/segment_sum
  data = np.array([5, 1, 7, 2, 3, 4, 1, 3], dtype=float)
  segment_ids = np.array([0, 0, 0, 1, 2, 2, 3, 3])

  # Scatter-add into a zero vector with one slot per segment id.
  num_segments = np.max(segment_ids) + 1
  accumulator = jnp.zeros_like(data, shape=num_segments)
  totals = accumulator.at[segment_ids].add(data)

  self.assertAllClose(totals, np.array([13, 2, 7, 4]), check_dtypes=False)
def testSegmentSum(self):
  data = jnp.array([5, 1, 7, 2, 3, 4, 1, 3])
  in_range_ids = jnp.array([0, 0, 0, 1, 2, 2, 3, 3])

  # Each case: (segment_ids, extra keyword arguments, expected sums).
  cases = [
      # explicit num_segments
      (in_range_ids, dict(num_segments=4), [13, 2, 7, 4]),
      # explicit num_segments larger than the highest id: extra slot stays 0
      (in_range_ids, dict(num_segments=5), [13, 2, 7, 4, 0]),
      # no explicit num_segments
      (in_range_ids, {}, [13, 2, 7, 4]),
      # negative ids and ids >= num_segments: per the expected values, those
      # entries contribute nothing to the result
      (jnp.array([0, 4, 8, 1, 2, -6, -1, 3]), dict(num_segments=4),
       [5, 2, 3, 3]),
      # negative ids without explicit num_segments: the expected output has
      # largest id + 1 entries and the negative ids contribute nothing
      (jnp.array([3, 3, 3, 4, 5, 5, -7, -6]), {}, [0, 0, 0, 13, 2, 7]),
  ]

  for segment_ids, kwargs, expected in cases:
    ans = ops.segment_sum(data, segment_ids, **kwargs)
    self.assertAllClose(ans, jnp.array(expected), check_dtypes=False)
def testSegmentSumOutOfBounds(self):
  # Both segment ids (2 and 3) are >= num_segments; the summed value and its
  # gradient with respect to the data must be zero.
  num_segments = 2
  data = np.array([0, 0], dtype=np.float32)
  segment_ids = np.array([2, 3])

  def total(values, ids):
    return jax.ops.segment_sum(values, ids, num_segments).sum()

  value, gradient = jax.value_and_grad(total)(data, segment_ids)
  self.assertAllClose(value, np.array(0., np.float32))
  self.assertAllClose(gradient, np.array([0., 0.], np.float32))
@parameterized.parameters(itertools.chain.from_iterable(
  jtu.sample_product_testcases(
    [dict(reducer=reducer, op=op, identity=identity)],
    dtype=[np.bool_],
    shape=[(8,), (7, 4), (6, 4, 2)],
    bucket_size=[None, 2],
    num_segments=[None, 1, 3],
  )
  for reducer, op, identity in [
    (ops.segment_min, np.minimum, True),
    (ops.segment_max, np.maximum, False),
  ]))
def testSegmentReduceBoolean(self, shape, dtype, reducer, op, identity,
                             num_segments, bucket_size):
  # Compares segment_min / segment_max on boolean data against a NumPy loop.
  # `identity` is the initial value of every output slot (True for min,
  # False for max). The integer-dtype adjustment of `identity` that used to
  # live here was dead code: the only parameterized dtype is np.bool_, which
  # is not a subdtype of np.integer.
  rng = jtu.rand_default(self.rng())
  # Segment ids are drawn with low=-2, high=3, so some are negative.
  idx_rng = jtu.rand_int(self.rng(), low=-2, high=3)
  args_maker = lambda: [rng(shape, dtype), idx_rng(shape[:1], jnp.int32)]

  jnp_fun = lambda data, segment_ids: reducer(
    data, segment_ids, num_segments=num_segments, bucket_size=bucket_size)

  def np_fun(data, segment_ids):
    # When num_segments is not given, use largest id + 1.
    size = num_segments if num_segments is not None else (segment_ids.max() + 1)
    out = np.full((size,) + shape[1:], identity, dtype)
    for i, val in zip(segment_ids, data):
      # Ids outside [0, size) are skipped.
      if 0 <= i < size:
        out[i] = op(out[i], val).astype(dtype)
    return out

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  # _CompileAndCheck is only run when num_segments is given explicitly.
  if num_segments is not None:
    self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.parameters(itertools.chain.from_iterable(
  jtu.sample_product_testcases(
    [dict(reducer=reducer, op=op, identity=identity)],
    dtype=default_dtypes,
    shape=[(8,), (7, 4), (6, 4, 2)],
    bucket_size=[None, 2],
    num_segments=[None, 1, 3],
  )
  for reducer, op, identity in [
    (ops.segment_sum, np.add, 0),
    (ops.segment_prod, np.multiply, 1),
    (ops.segment_min, np.minimum, float('inf')),
    (ops.segment_max, np.maximum, -float('inf')),
  ]))
def testSegmentReduce(self, shape, dtype, reducer, op, identity, num_segments, bucket_size):
  # Compares each segment reducer against a NumPy loop that starts every
  # output slot at the reduction's identity value.
  rng = jtu.rand_default(self.rng())
  # Segment ids are drawn with low=-2, high=3, so some are negative.
  idx_rng = jtu.rand_int(self.rng(), low=-2, high=3)

  def args_maker():
    return [rng(shape, dtype), idx_rng(shape[:1], jnp.int32)]

  # Integer dtypes cannot hold +/-inf; use the dtype's max/min instead.
  initial = identity
  if np.issubdtype(dtype, np.integer):
    if np.isposinf(identity):
      initial = np.iinfo(dtype).max
    elif np.isneginf(identity):
      initial = np.iinfo(dtype).min

  def jnp_fun(data, segment_ids):
    return reducer(
        data, segment_ids, num_segments=num_segments, bucket_size=bucket_size)

  def np_fun(data, segment_ids):
    if num_segments is not None:
      size = num_segments
    else:
      size = segment_ids.max() + 1
    out = np.full((size,) + shape[1:], initial, dtype)
    for row, segment in enumerate(segment_ids):
      # Ids outside [0, size) are skipped.
      if segment < 0 or segment >= size:
        continue
      out[segment] = op(out[segment], data[row]).astype(dtype)
    return out

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  # _CompileAndCheck is only run when num_segments is given explicitly.
  if num_segments is not None:
    self._CompileAndCheck(jnp_fun, args_maker)
def testIndexDtypeError(self):
  # https://github.com/jax-ml/jax/issues/2795
  jnp.array(1)  # get rid of startup warning
  every_other = slice(None, None, 2)
  # A strided-slice update must not emit any warning.
  with self.assertNoWarnings():
    jnp.zeros(5).at[every_other].set(1)
@jtu.sample_product(
  [dict(idx=idx, idx_type=idx_type)
    for idx, idx_type in [
      ([0], "array"),
      ([0, 0], "array"),
      ([[0, 0]], "tuple"),
      ([0, [0, 1]], "tuple"),
      ([0, np.arange(2)], "tuple"),
      ([0, None], "tuple"),
      ([0, slice(None)], "tuple"),
    ]
  ],
)
def testIndexSequenceDeprecation(self, idx, idx_type):
  # A list used as an index must raise TypeError, while the same index
  # converted to an array or tuple (per idx_type) must work without warnings.
  normalize, msg = {
      "array": (np.array, ARRAY_MSG),
      "tuple": (tuple, TUPLE_MSG),
  }[idx_type]
  x = jnp.arange(6).reshape(3, 2)

  # Reads.
  with self.assertRaisesRegex(TypeError, msg):
    x[idx]
  with self.assertNoWarnings():
    x[normalize(idx)]

  # Updates through `.at`.
  with self.assertRaisesRegex(TypeError, msg):
    x.at[idx].set(0)
  with self.assertNoWarnings():
    x.at[normalize(idx)].set(0)
def testIndexedUpdateAliasingBug(self):
  # https://github.com/jax-ml/jax/issues/7461
  def shift_and_increment(arr):
    # The update value reads from the same array that is being updated.
    return arr.at[1:].set(1 + arr[:-1])

  zeros = jnp.zeros(8)
  eager = shift_and_increment(zeros)
  jitted = jax.jit(shift_and_increment)(zeros)
  self.assertArraysEqual(eager, jitted)
def testScatterValuesCastToTargetDType(self):
  # https://github.com/jax-ml/jax/issues/15505
  target = jnp.zeros(1, dtype=jnp.uint32)
  val = 2**32 - 1  # too large for int32

  # The value must round-trip both as a uint32 scalar and as a Python int.
  from_uint32 = target.at[0].set(jnp.uint32(val))
  self.assertEqual(int(from_uint32[0]), val)

  from_python_int = target.at[0].set(val)
  self.assertEqual(int(from_python_int[0]), val)
def testGradOfVmapOfScatter(self):
  # Regression test for https://github.com/jax-ml/jax/issues/25878
  def clipped_get(arr, idx):
    return arr.at[idx].get(mode='clip')

  values = jnp.array([1.0])
  indices = jnp.array([1])  # out-of-bound index
  expected = jnp.array([[1.0]])

  plain_jacobian = jax.jacrev(clipped_get)(values, indices)
  self.assertArraysEqual(plain_jacobian, expected)

  vmapped = jax.vmap(clipped_get, (None, 0))
  self.assertArraysEqual(jax.jacrev(vmapped)(values, indices), expected)
if __name__ == "__main__":
  # Run the tests through absl's runner with JAX's test loader.
  test_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=test_loader)