2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2018 The JAX Authors.
|
2018-11-17 18:03:33 -08:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
from __future__ import annotations
|
2018-11-21 13:27:26 -08:00
|
|
|
|
2019-02-22 07:55:36 -05:00
|
|
|
import enum
|
2018-11-17 18:03:33 -08:00
|
|
|
from functools import partial
|
|
|
|
import itertools
|
2021-08-11 17:32:36 -04:00
|
|
|
import typing
|
2023-12-11 13:59:29 +00:00
|
|
|
from typing import Any
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
from absl.testing import absltest
|
|
|
|
from absl.testing import parameterized
|
|
|
|
|
2020-07-14 13:03:24 -07:00
|
|
|
import numpy as np
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2021-08-03 12:19:36 -07:00
|
|
|
import jax
|
2022-04-20 16:04:12 -07:00
|
|
|
from jax import lax
|
2020-03-06 14:59:51 -05:00
|
|
|
from jax import numpy as jnp
|
2019-02-22 07:55:36 -05:00
|
|
|
from jax import ops
|
2022-03-09 18:18:16 -08:00
|
|
|
|
2023-10-12 13:15:22 +01:00
|
|
|
from jax._src import config
|
2021-11-30 15:43:06 -08:00
|
|
|
from jax._src import dtypes
|
2021-09-24 07:02:08 -07:00
|
|
|
from jax._src import test_util as jtu
|
2021-01-11 14:20:32 -08:00
|
|
|
from jax._src import util
|
2022-03-09 18:18:16 -08:00
|
|
|
from jax._src.lax import lax as lax_internal
|
2023-10-12 13:15:22 +01:00
|
|
|
from jax._src.util import NumpyComplexWarning
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:30:59 -05:00
|
|
|
config.parse_flags_with_absl()
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
# We disable the whitespace continuation check in this file because otherwise it
|
|
|
|
# makes the test name formatting unwieldy.
|
|
|
|
# pylint: disable=bad-continuation
|
|
|
|
|
|
|
|
|
2020-11-18 16:13:34 -08:00
|
|
|
# Regex patterns for the "non-tuple sequence" indexing error text. The two
# patterns differ only in the suggested replacement spelling at the end.
# NOTE(review): neither constant is referenced in the part of the file visible
# here; presumably later tests match raised errors against them.
ARRAY_MSG = (
  r"Using a non-tuple sequence for multidimensional indexing is not allowed"
  r".*arr\[array\(seq\)\]"
)
TUPLE_MSG = (
  r"Using a non-tuple sequence for multidimensional indexing is not allowed"
  r".*arr\[tuple\(seq\)\]"
)
|
|
|
|
|
|
|
|
|
2020-07-08 13:21:48 -07:00
|
|
|
# Dtype groups used to parameterize the tests below, all taken from jtu.dtypes.
# Floating-point dtypes only (used by the gradient tests).
float_dtypes = jtu.dtypes.floating
# Floating-point plus integer dtypes.
default_dtypes = float_dtypes + jtu.dtypes.integer
# Floating-point, integer and boolean dtypes.
all_dtypes = default_dtypes + jtu.dtypes.boolean
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2021-08-11 17:32:36 -04:00
|
|
|
class IndexSpec(typing.NamedTuple):
  """One indexing test case: an operand shape and the indexer applied to it."""
  # Shape of the array that the tests generate and then index.
  shape: tuple[int, ...]
  # Index expression used as ``x[indexer]`` (int, slice, tuple, array, ...).
  indexer: Any
  # Presumably the expected shape of ``x[indexer]``; it defaults to None and
  # the tests visible in this part of the file unpack it without reading it.
  out_shape: tuple[int, ...] | None = None
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
def check_grads(f, args, order, atol=None, rtol=None, eps=None):
  """Runs the JVP and VJP checks from ``jtu`` on ``f`` at ``args``.

  Args:
    f: function to differentiate.
    args: tuple of positional arguments at which ``f`` is checked.
    order: accepted for call-site compatibility but currently unused (see the
      TODO below).
    atol: tolerance forwarded to ``jtu.check_jvp`` / ``jtu.check_vjp``.
    rtol: tolerance forwarded to ``jtu.check_jvp`` / ``jtu.check_vjp``.
    eps: step-size argument forwarded to ``jtu.check_jvp`` / ``jtu.check_vjp``.

  Any of ``atol``, ``rtol`` and ``eps`` that is falsy (``None`` or ``0``) is
  replaced by a default: 1e-6 when x64 mode is enabled, 1e-2 otherwise.
  """
  # TODO(mattjj,dougalm): add higher-order check
  fallback = 1e-6 if config.enable_x64.value else 1e-2
  atol, rtol, eps = (given or fallback for given in (atol, rtol, eps))
  for checker, transform in ((jtu.check_jvp, jax.jvp), (jtu.check_vjp, jax.vjp)):
    checker(f, partial(transform, f), args, atol, rtol, eps)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2019-02-22 07:55:36 -05:00
|
|
|
# Named groups of IndexSpec cases consumed by testStaticIndexing and
# testStaticIndexingGrads. Each entry is (group name, list of IndexSpec).
STATIC_INDEXING_TESTS = [
  # Integer indices (including negative ones).
  ("OneIntIndex", [
    IndexSpec(shape=(3,), indexer=1, out_shape=()),
    IndexSpec(shape=(3, 3), indexer=0, out_shape=(3,)),
    IndexSpec(shape=(3, 4, 5), indexer=2, out_shape=(4, 5)),
    IndexSpec(shape=(3,), indexer=-1, out_shape=()),
    IndexSpec(shape=(3,), indexer=-2, out_shape=()),
  ]),
  ("TwoIntIndices", [
    IndexSpec(shape=(3, 3), indexer=(2, 1), out_shape=()),
    IndexSpec(shape=(3, 4, 5), indexer=(1, 2), out_shape=(5,)),
    IndexSpec(shape=(3, 4, 5), indexer=(-1, 2), out_shape=(5,)),
  ]),
  ("ThreeIntIndices", [
    IndexSpec(shape=(3, 4, 5), indexer=(1, 2, 3), out_shape=()),
  ]),
  # Single slices with unit, negative and non-unit strides.
  ("OneSliceIndex", [
    IndexSpec(shape=(10,), indexer=slice(1, 3), out_shape=(2,)),
    IndexSpec(shape=(10,), indexer=slice(1, -1), out_shape=(8,)),
    IndexSpec(shape=(10,), indexer=slice(None, -1), out_shape=(9,)),
    IndexSpec(shape=(10,), indexer=slice(None, None, None), out_shape=(10,)),
    IndexSpec(shape=(10, 8), indexer=slice(1, 3), out_shape=(2, 8)),
    IndexSpec(shape=(10, 8), indexer=slice(1, None), out_shape=(9, 8)),
    IndexSpec(shape=(10, 8), indexer=slice(None, 3), out_shape=(3, 8)),
    IndexSpec(shape=(10, 8), indexer=slice(-3, None), out_shape=(3, 8)),
  ]),
  ("OneSliceIndexNegativeStride", [
    IndexSpec(shape=(10,), indexer=slice(3, 1, -1), out_shape=(2,)),
    IndexSpec(shape=(10,), indexer=slice(1, 8, -1), out_shape=(0,)),
    IndexSpec(shape=(10,), indexer=slice(None, 1, -2), out_shape=(4,)),
    IndexSpec(shape=(10,), indexer=slice(None, None, -1), out_shape=(10,)),
    IndexSpec(shape=(10, 8), indexer=slice(3, 1, -1), out_shape=(2, 8)),
    IndexSpec(shape=(10, 8), indexer=slice(0, 8, -1), out_shape=(0, 8)),
    IndexSpec(shape=(10, 8), indexer=slice(None, None, -1), out_shape=(10, 8)),
  ]),
  # Slice bounds that lie outside the axis extent.
  ("SliceIndexClamping", [
    IndexSpec(shape=(10,), indexer=slice(2, 11, 1), out_shape=(8,)),
    IndexSpec(shape=(10,), indexer=slice(11, 12, 1), out_shape=(0,)),
    IndexSpec(shape=(10,), indexer=slice(-11, -2, 1), out_shape=(8,)),
    IndexSpec(shape=(10,), indexer=slice(-2, -12, -1), out_shape=(9,)),
    IndexSpec(shape=(10,), indexer=slice(12, -12, -1), out_shape=(10,)),
  ]),
  ("OneSliceIndexNonUnitStride", [
    IndexSpec(shape=(10,), indexer=slice(0, 8, 2), out_shape=(4,)),
    IndexSpec(shape=(10,), indexer=slice(0, 8, 3), out_shape=(3,)),
    IndexSpec(shape=(10,), indexer=slice(1, 3, 2), out_shape=(1,)),
    IndexSpec(shape=(10,), indexer=slice(1, None, 2), out_shape=(5,)),
    IndexSpec(shape=(10,), indexer=slice(None, 1, -2), out_shape=(4,)),
    IndexSpec(shape=(10, 8), indexer=slice(1, 8, 3), out_shape=(3, 8)),
    IndexSpec(shape=(10, 8), indexer=slice(None, None, 2), out_shape=(5, 8)),
    IndexSpec(shape=(10, 8), indexer=slice(None, 1, -2), out_shape=(4, 8)),
    IndexSpec(shape=(10, 8), indexer=slice(None, None, -2), out_shape=(5, 8)),
  ]),
  # Tuples of slices.
  ("TwoSliceIndices", [
    IndexSpec(
      shape=(10, 8),
      indexer=(slice(1, 3), slice(0, 2)),
      out_shape=(2, 2),
    ),
    IndexSpec(
      shape=(10, 8),
      indexer=(slice(1, None), slice(None, 2)),
      out_shape=(9, 2),
    ),
    IndexSpec(
      shape=(10, 8),
      indexer=(slice(None, None, -1), slice(None, 2)),
      out_shape=(10, 2),
    ),
    IndexSpec(
      shape=(10, 8, 3),
      indexer=(slice(1, 3), slice(0, 2)),
      out_shape=(2, 2, 3),
    ),
    IndexSpec(
      shape=(10, 8, 3),
      indexer=(slice(1, 3), slice(0, None)),
      out_shape=(2, 8, 3),
    ),
    IndexSpec(
      shape=(10, 8, 3),
      indexer=(slice(1, None), slice(0, 2)),
      out_shape=(9, 2, 3),
    ),
  ]),
  ("OneColonIndex", [
    IndexSpec(shape=(3,), indexer=slice(None), out_shape=(3,)),
    IndexSpec(shape=(3, 4), indexer=slice(None), out_shape=(3, 4)),
  ]),
  ("MultipleColonIndices", [
    IndexSpec(
      shape=(3, 4),
      indexer=(slice(None), slice(None)),
      out_shape=(3, 4),
    ),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=(slice(None), slice(None)),
      out_shape=(3, 4, 5),
    ),
  ]),
  ("MixedSliceIndices", [
    IndexSpec(
      shape=(10, 4),
      indexer=(slice(None), slice(0, 2)),
      out_shape=(10, 2),
    ),
    IndexSpec(
      shape=(10, 4),
      indexer=(1, slice(None)),
      out_shape=(4,),
    ),
  ]),
  # Ellipsis, None (new axis) and empty-tuple indexers.
  ("EllipsisIndex", [
    IndexSpec(shape=(3,), indexer=Ellipsis, out_shape=(3,)),
    IndexSpec(shape=(3, 4), indexer=Ellipsis, out_shape=(3, 4)),
    IndexSpec(shape=(3, 4, 5), indexer=(0, Ellipsis), out_shape=(4, 5)),
    IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, 2, 3), out_shape=(3,)),
  ]),
  ("NoneIndex", [
    IndexSpec(shape=(), indexer=None, out_shape=(1,)),
    IndexSpec(shape=(), indexer=(None, None), out_shape=(1, 1)),
    IndexSpec(shape=(), indexer=(Ellipsis, None), out_shape=(1,)),
    IndexSpec(shape=(3,), indexer=None, out_shape=(1, 3)),
    IndexSpec(shape=(3, 4), indexer=None, out_shape=(1, 3, 4)),
    IndexSpec(shape=(3, 4), indexer=(Ellipsis, None), out_shape=(3, 4, 1)),
    IndexSpec(shape=(3, 4), indexer=(0, None, Ellipsis), out_shape=(1, 4)),
    IndexSpec(shape=(3, 4, 5), indexer=(1, None, Ellipsis), out_shape=(1, 4, 5)),
  ]),
  ("EmptyIndex", [
    IndexSpec(shape=(), indexer=(), out_shape=()),
    IndexSpec(shape=(3,), indexer=(), out_shape=(3,)),
    IndexSpec(shape=(3, 4), indexer=(), out_shape=(3, 4)),
  ]),
  # Scalar index spelled three ways, combined with a slice and an int array.
  ("TupleOfIntAndSliceAndIntArray", [
    IndexSpec(
      shape=(3, 2, 3),
      indexer=(0, slice(None), np.arange(3)),
      out_shape=(3, 2),
    ),
    IndexSpec(
      shape=(3, 2, 3),
      indexer=(np.int32(1), slice(None), np.arange(3)),
      out_shape=(3, 2),
    ),
    IndexSpec(
      shape=(3, 2, 3),
      indexer=(np.array(2), slice(None), np.arange(3)),
      out_shape=(3, 2),
    ),
  ]),
]
|
|
|
|
|
|
|
|
# Integer indexers with at least one index outside the valid range of its
# axis. testStaticIndexingGrads appends these to STATIC_INDEXING_TESTS for
# every mode except "promise_in_bounds".
STATIC_INDEXING_OUT_OF_BOUNDS_TESTS = [
  ("OneIntIndex", [
    IndexSpec(shape=(3,), indexer=-4, out_shape=()),
    IndexSpec(shape=(3, 3), indexer=3, out_shape=(3,)),
    IndexSpec(shape=(3, 4, 5), indexer=4, out_shape=(4, 5)),
  ]),
  ("TwoIntIndices", [
    IndexSpec(shape=(3, 3), indexer=(2, -4), out_shape=()),
    # Two integer indices on a rank-3 operand leave the trailing axis, so the
    # result shape is (5,), matching the (-4, 4) case below.
    IndexSpec(shape=(3, 4, 5), indexer=(3, 2), out_shape=(5,)),
    IndexSpec(shape=(3, 4, 5), indexer=(-4, 4), out_shape=(5,)),
  ]),
]
|
|
|
|
|
|
|
|
|
2019-05-03 17:47:40 -07:00
|
|
|
# Named groups of IndexSpec cases whose indexers are integer arrays, lists of
# Python ints, or tuples of those. Consumed by testAdvancedIntegerIndexing.
# Several cases deliberately repeat an index value.
ADVANCED_INDEXING_TESTS = [
  ("One1DIntArrayIndex", [
    IndexSpec(shape=(3,), indexer=np.array([0, 1]), out_shape=(2,)),
    IndexSpec(shape=(3, 3), indexer=np.array([1, 2, 1]), out_shape=(3, 3)),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=np.array([0, 2, 0, 1]),
      out_shape=(4, 4, 5),
    ),
    IndexSpec(shape=(3,), indexer=np.array([-1, 1]), out_shape=(2,)),
    IndexSpec(shape=(3,), indexer=np.array([-2, -1]), out_shape=(2,)),
    IndexSpec(
      shape=(0,),
      indexer=np.array([], dtype=np.int32),
      out_shape=(0,),
    ),
  ]),
  ("One2DIntArrayIndex", [
    IndexSpec(shape=(3,), indexer=np.array([[0, 0]]), out_shape=(1, 2)),
    IndexSpec(
      shape=(3, 3),
      indexer=np.array([[1, 2, 1], [0, 1, -1]]),
      out_shape=(2, 3, 3),
    ),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=np.array([[0, 2, 0, 1], [-1, -2, 1, 0]]),
      out_shape=(2, 4, 4, 5),
    ),
  ]),
  ("Two1DIntArrayIndicesNoBroadcasting", [
    IndexSpec(
      shape=(3, 3),
      indexer=(np.array([0, 1]), np.array([1, 2])),
      out_shape=(2,),
    ),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=(np.array([0, 2, 0, 1]), np.array([-1, 0, -1, 2])),
      out_shape=(4, 5),
    ),
  ]),
  ("Two1DIntArrayIndicesWithBroadcasting", [
    IndexSpec(
      shape=(3, 3),
      indexer=(np.array([[0, 1]]), np.array([1, 2])),
      out_shape=(1, 2),
    ),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=(np.array([[0, 2, 0, 1]]), np.array([-1, 0, -1, 2])),
      out_shape=(1, 4, 5),
    ),
  ]),
  ("ArrayOfInts", [
    IndexSpec(shape=(3,), indexer=np.array([0, 1, 0]), out_shape=(3,)),
    IndexSpec(shape=(3, 4, 5), indexer=np.array([0, -1]), out_shape=(2, 4, 5)),
  ]),
  ("TupleOfListsOfPythonInts", [
    IndexSpec(shape=(3, 4, 5), indexer=([0, 1],), out_shape=(2, 4, 5)),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=([[0], [-1]], [[2, 3, 0, 3]]),
      out_shape=(2, 4, 5),
    ),
  ]),
  ("TupleOfPythonIntsAndIntArrays", [
    IndexSpec(shape=(3, 4, 5), indexer=(0, np.array([0, 1])), out_shape=(2, 5)),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=(0, 1, np.array([[2, 3, 0, 3]])),
      out_shape=(1, 4),
    ),
  ]),
  ("TupleOfListsOfPythonIntsAndIntArrays", [
    IndexSpec(
      shape=(3, 4, 5),
      indexer=([0, 1], np.array([0])),
      out_shape=(2, 5),
    ),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=([[0], [-1]], np.array([[2, 3, 0, 3]])),
      out_shape=(2, 4, 5),
    ),
  ]),
]
|
|
|
|
|
|
|
|
# Integer-array indexing cases mirroring the groups of ADVANCED_INDEXING_TESTS.
# NOTE(review): going by the name, each indexer is presumably chosen so that no
# element is selected twice; the consumers are not visible in this part of the
# file.
ADVANCED_INDEXING_TESTS_NO_REPEATS = [
  ("One1DIntArrayIndex", [
    IndexSpec(shape=(3,), indexer=np.array([0, 1]), out_shape=(2,)),
    IndexSpec(shape=(3, 3), indexer=np.array([1, 2, 0]), out_shape=(3, 3)),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=np.array([0, 2, 1]),
      out_shape=(3, 4, 5),
    ),
    IndexSpec(shape=(3,), indexer=np.array([-1, 1]), out_shape=(2,)),
    IndexSpec(shape=(3,), indexer=np.array([-2, -1]), out_shape=(2,)),
    IndexSpec(
      shape=(0,),
      indexer=np.array([], dtype=np.int32),
      out_shape=(0,),
    ),
  ]),
  ("One2DIntArrayIndex", [
    IndexSpec(shape=(3,), indexer=np.array([[0, 1]]), out_shape=(1, 2)),
    IndexSpec(
      shape=(6, 6),
      indexer=np.array([[1, 2, 0], [3, 4, -1]]),
      out_shape=(2, 3, 6),
    ),
  ]),
  ("Two1DIntArrayIndicesNoBroadcasting", [
    IndexSpec(
      shape=(3, 3),
      indexer=(np.array([0, 1]), np.array([1, 2])),
      out_shape=(2,),
    ),
    IndexSpec(
      shape=(4, 5, 6),
      indexer=(np.array([0, 2, 1, 3]), np.array([-1, 0, -2, 1])),
      out_shape=(4, 6),
    ),
  ]),
  ("Two1DIntArrayIndicesWithBroadcasting", [
    IndexSpec(
      shape=(3, 3),
      indexer=(np.array([[0, 1]]), np.array([1, 2])),
      out_shape=(1, 2),
    ),
    IndexSpec(
      shape=(4, 5, 6),
      indexer=(np.array([[0, 2, -1, 1]]), np.array([-1, 0, -2, 2])),
      out_shape=(1, 4, 6),
    ),
  ]),
  ("ArrayOfInts", [
    IndexSpec(shape=(3,), indexer=np.array([0, 2, 1]), out_shape=(3,)),
    IndexSpec(shape=(3, 4, 5), indexer=np.array([0, -1]), out_shape=(2, 4, 5)),
  ]),
  ("TupleOfListsOfPythonInts", [
    IndexSpec(shape=(3, 4, 5), indexer=([0, 1],), out_shape=(2, 4, 5)),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=([[0], [-1]], [[2, 3, 0]]),
      out_shape=(2, 3, 5),
    ),
  ]),
  ("TupleOfPythonIntsAndIntArrays", [
    IndexSpec(shape=(3, 4, 5), indexer=(0, np.array([0, 1])), out_shape=(2, 5)),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=(0, 1, np.array([[2, 3, 0]])),
      out_shape=(1, 3),
    ),
  ]),
  ("TupleOfListsOfPythonIntsAndIntArrays", [
    IndexSpec(
      shape=(3, 4, 5),
      indexer=([0, 1], np.array([0])),
      out_shape=(2, 5),
    ),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=([[0], [-1]], np.array([[2, 3, 0]])),
      out_shape=(2, 3, 5),
    ),
  ]),
]
|
|
|
|
|
2020-07-21 23:16:27 -07:00
|
|
|
# Integer-array indexing cases with the same group names as
# ADVANCED_INDEXING_TESTS_NO_REPEATS, minus "ArrayOfInts".
# NOTE(review): going by the name, the index values are presumably also listed
# in sorted order; the consumers are not visible in this part of the file.
ADVANCED_INDEXING_TESTS_NO_REPEATS_SORTED = [
  ("One1DIntArrayIndex", [
    IndexSpec(shape=(3,), indexer=np.array([0, 1]), out_shape=(2,)),
    IndexSpec(shape=(3, 3), indexer=np.array([0, 1, 2]), out_shape=(3, 3)),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=np.array([0, 1, 2]),
      out_shape=(3, 4, 5),
    ),
    IndexSpec(shape=(3,), indexer=np.array([-1, 1]), out_shape=(2,)),
    IndexSpec(shape=(3,), indexer=np.array([-2, -1]), out_shape=(2,)),
    IndexSpec(
      shape=(0,),
      indexer=np.array([], dtype=np.int32),
      out_shape=(0,),
    ),
  ]),
  ("One2DIntArrayIndex", [
    IndexSpec(shape=(3,), indexer=np.array([[0, 1]]), out_shape=(1, 2)),
    IndexSpec(
      shape=(6, 6),
      indexer=np.array([[-1, 0, 1], [2, 3, 4]]),
      out_shape=(2, 3, 6),
    ),
  ]),
  ("Two1DIntArrayIndicesNoBroadcasting", [
    IndexSpec(
      shape=(3, 3),
      indexer=(np.array([0, 1]), np.array([1, 2])),
      out_shape=(2,),
    ),
    IndexSpec(
      shape=(4, 5, 6),
      indexer=(np.array([0, 1, 2, 3]), np.array([-2, -1, 0, 1])),
      out_shape=(4, 6),
    ),
  ]),
  ("Two1DIntArrayIndicesWithBroadcasting", [
    IndexSpec(
      shape=(3, 3),
      indexer=(np.array([[0, 1]]), np.array([1, 2])),
      out_shape=(1, 2),
    ),
    IndexSpec(
      shape=(4, 5, 6),
      indexer=(np.array([[-1, 0, 1, 2]]), np.array([-2, -1, 0, 2])),
      out_shape=(1, 4, 6),
    ),
  ]),
  ("TupleOfListsOfPythonInts", [
    IndexSpec(shape=(3, 4, 5), indexer=([0, 1],), out_shape=(2, 4, 5)),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=([[0], [-1]], [[0, 2, 3]]),
      out_shape=(2, 3, 5),
    ),
  ]),
  ("TupleOfPythonIntsAndIntArrays", [
    IndexSpec(shape=(3, 4, 5), indexer=(0, np.array([0, 1])), out_shape=(2, 5)),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=(0, 1, np.array([[0, 2, 3]])),
      out_shape=(1, 3),
    ),
  ]),
  ("TupleOfListsOfPythonIntsAndIntArrays", [
    IndexSpec(
      shape=(3, 4, 5),
      indexer=([0, 1], np.array([0])),
      out_shape=(2, 5),
    ),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=([[0], [-1]], np.array([[0, 2, 3]])),
      out_shape=(2, 3, 5),
    ),
  ]),
]
|
|
|
|
|
2021-08-11 17:32:36 -04:00
|
|
|
|
2019-07-14 10:57:41 -04:00
|
|
|
# Cases that combine integer-array indices with slices, Ellipsis and None in a
# single indexer. This list is also the prefix of MIXED_ADVANCED_INDEXING_TESTS.
MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS = [
  ("SlicesAndOneIntArrayIndex", [
    IndexSpec(
      shape=(2, 3),
      indexer=(np.array([0, 1]), slice(1, 2)),
      out_shape=(2, 1),
    ),
    IndexSpec(
      shape=(2, 3),
      indexer=(slice(0, 2), np.array([0, 2])),
      out_shape=(2, 2),
    ),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=(Ellipsis, np.array([0, 2]), slice(None)),
      out_shape=(3, 2, 5),
    ),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=(Ellipsis, np.array([[0, 2], [1, 3]]), slice(None)),
      out_shape=(3, 2, 2, 5),
    ),
  ]),
  ("SlicesAndTwoIntArrayIndices", [
    IndexSpec(
      shape=(3, 4, 5),
      indexer=(Ellipsis, np.array([0, 2]), np.array([-1, 2])),
      out_shape=(3, 2),
    ),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=(np.array([0, 2]), Ellipsis, np.array([-1, 2])),
      out_shape=(2, 4),
    ),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=(np.array([0, 2]), np.array([-1, 2]), Ellipsis),
      out_shape=(2, 5),
    ),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=(np.array([0, 2]), np.array([-1, 2]), slice(1, 3)),
      out_shape=(2, 2),
    ),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=(np.array([0, 2]), slice(1, 3), np.array([-1, 2])),
      out_shape=(2, 2),
    ),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=(
        np.array([0, 2, -2]),
        slice(None, None, 2),
        np.array([-1, 2, 1]),
      ),
      out_shape=(3, 2),
    ),
  ]),
  ("NonesAndIntArrayIndices", [
    IndexSpec(
      shape=(3, 4, 5),
      indexer=(np.array([0, 2]), None, np.array([-1, 2])),
      out_shape=(2, 1, 5),
    ),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=(np.array([0, 2]), None, None, np.array([-1, 2])),
      out_shape=(2, 1, 1, 5),
    ),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=(Ellipsis, np.array([0, 2]), None, None, np.array([-1, 2])),
      out_shape=(2, 3, 1, 1),
    ),
  ]),
  ("IntArrayWithInt32Type", [
    IndexSpec(
      shape=(3, 4),
      indexer=(Ellipsis, np.array(1, dtype=np.int32)),
      out_shape=(3,),
    ),
  ]),
  ("EllipsisWithArrayIndices", [
    IndexSpec(
      shape=(3, 4, 5),
      indexer=(np.array([0, 1]), ..., np.array([0, 1])),
      out_shape=(2, 4),
    ),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=(slice(None), np.array([0, 1]), ..., np.array([0, 1])),
      out_shape=(2, 3),
    ),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=(slice(None), ..., np.array([0, 1]), np.array([0, 1])),
      out_shape=(3, 2),
    ),
  ]),
]
|
|
|
|
|
2021-08-11 17:32:36 -04:00
|
|
|
|
2019-07-14 10:57:41 -04:00
|
|
|
# Every case from MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS, followed by extra
# cases in which an index value (or index pair) occurs more than once.
MIXED_ADVANCED_INDEXING_TESTS = MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS + [
  ("SlicesAndOneIntArrayIndex", [
    IndexSpec(
      shape=(3, 4, 5),
      indexer=(Ellipsis, np.array([[0, 2], [1, 1]]), slice(None)),
      out_shape=(3, 2, 2, 5),
    ),
  ]),
  ("SlicesAndTwoIntArrayIndices", [
    IndexSpec(
      shape=(3, 4, 5),
      indexer=(
        np.array([0, 2, -2]),
        slice(None, None, 2),
        np.array([-1, 2, -1]),
      ),
      out_shape=(3, 2),
    ),
    IndexSpec(
      shape=(3, 4, 5),
      indexer=(
        np.array([[0, 2], [2, 0]]),
        Ellipsis,
        np.array([[1, 0], [1, 0]]),
      ),
      out_shape=(2, 2, 4),
    ),
  ]),
]
|
|
|
|
|
|
|
|
# Values passed as ``mode=`` to ``x.at[indexer].get(...)`` by
# testStaticIndexingGrads, which adds out-of-bounds cases for every mode
# except "promise_in_bounds".
MODES = ["clip", "drop", "promise_in_bounds"]
|
|
|
|
|
2019-07-14 10:57:41 -04:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
class IndexingTest(jtu.JaxTestCase):
|
|
|
|
"""Tests for Numpy indexing translation rules."""
|
|
|
|
|
2022-10-06 15:53:18 +00:00
|
|
|
@jtu.sample_product(
  [
    {"name": group, "shape": spec.shape, "indexer": spec.indexer}
    for group, specs in STATIC_INDEXING_TESTS
    for spec in specs
  ],
  dtype=all_dtypes
)
def testStaticIndexing(self, name, shape, dtype, indexer):
  """Compares static indexing in jax.numpy against NumPy.

  Both ``x[indexer]`` and ``x.at[indexer].get()`` are checked against
  ``np.asarray(x)[indexer]`` on a random array of the given shape and dtype,
  and each JAX variant is also passed to ``self._CompileAndCheck``.
  """
  sampler = jtu.rand_default(self.rng())

  def make_args():
    return [sampler(shape, dtype)]

  def reference(x):
    return np.asarray(x)[indexer]

  def via_getitem(x):
    return jnp.asarray(x)[indexer]

  def via_at_get(x):
    # Tests x.at[...].get(...) as well.
    return jnp.asarray(x).at[indexer].get()

  for candidate in (via_getitem, via_at_get):
    self._CheckAgainstNumpy(reference, candidate, make_args)
    self._CompileAndCheck(candidate, make_args)
|
|
|
|
|
|
|
|
def testStaticIndexingWithJaxArray(self):
  """Checks slicing when the slice fields are 0-d JAX / NumPy int32 arrays.

  The slice ``2:11:1`` is built from ``jnp.array`` start and step values and
  an ``np.array`` stop value, then applied to a random int32 array of shape
  (10,) through both ``x[indexer]`` and ``x.at[indexer].get()``.
  """
  shape = (10,)
  start = jnp.array(2, dtype=np.int32)
  stop = np.array(11, dtype=np.int32)
  step = jnp.array(1, dtype=np.int32)
  indexer = slice(start, stop, step)
  sampler = jtu.rand_default(self.rng())

  def make_args():
    return [sampler(shape, np.int32)]

  def reference(x):
    return np.asarray(x)[indexer]

  def via_getitem(x):
    return jnp.asarray(x)[indexer]

  def via_at_get(x):
    # Tests x.at[...].get(...) as well.
    return jnp.asarray(x).at[indexer].get()

  for candidate in (via_getitem, via_at_get):
    self._CheckAgainstNumpy(reference, candidate, make_args)
    self._CompileAndCheck(candidate, make_args)
|
|
|
|
|
2022-10-06 15:53:18 +00:00
|
|
|
@jtu.sample_product(
  funcname=["negative", "sin", "cos", "square", "sqrt", "log", "exp"],
)
def testIndexApply(self, funcname, size=10, dtype='float32'):
  """Compares ``x.at[idx].apply(jnp.<funcname>)`` with ``np.<funcname>.at``.

  The comparison runs twice: once with an array of random integer indices
  drawn from [-size, size) that is passed in as an argument, and once with
  the fixed slice ``1:5`` bound ahead of time.
  """
  value_sampler = jtu.rand_default(self.rng())
  index_sampler = jtu.rand_int(self.rng(), -size, size)
  np_ufunc = getattr(np, funcname)
  jnp_ufunc = getattr(jnp, funcname)

  @jtu.ignore_warning(category=RuntimeWarning)
  def np_op(x, idx):
    # ufunc.at mutates its operand, so work on a copy.
    result = x.copy()
    np_ufunc.at(result, idx)
    return result

  def jnp_op(x, idx):
    return jnp.asarray(x).at[idx].apply(jnp_ufunc)

  # A tolerance is only set for log/exp on TPU; None is passed otherwise.
  if jtu.test_device_matches(["tpu"]) and funcname in ("log", "exp"):
    tol = 5e-5
  else:
    tol = None

  # Test with traced integer index
  def make_indexed_args():
    return [value_sampler(size, dtype), index_sampler(size, int)]

  self._CheckAgainstNumpy(np_op, jnp_op, make_indexed_args, atol=tol)
  self._CompileAndCheck(jnp_op, make_indexed_args)

  # Test with slice index
  fixed_slice = slice(1, 5)
  np_op_sliced = partial(np_op, idx=fixed_slice)
  jnp_op_sliced = partial(jnp_op, idx=fixed_slice)

  def make_value_args():
    return [value_sampler(size, dtype)]

  self._CheckAgainstNumpy(
    np_op_sliced, jnp_op_sliced, make_value_args, atol=tol, rtol=tol)
  self._CompileAndCheck(jnp_op_sliced, make_value_args)
|
|
|
|
|
2023-07-10 16:42:45 -07:00
|
|
|
def testIndexApplyBatchingBug(self):
|
2024-09-20 07:51:48 -07:00
|
|
|
# https://github.com/jax-ml/jax/issues/16655
|
2023-07-10 16:42:45 -07:00
|
|
|
arr = jnp.array([[1, 2, 3, 4, 5, 6]])
|
|
|
|
ind = jnp.array([3])
|
|
|
|
func = lambda a, i: a.at[i].apply(lambda x: x - 1)
|
|
|
|
expected = jnp.array(list(map(func, arr, ind)))
|
|
|
|
out = jax.vmap(func)(arr, ind)
|
|
|
|
self.assertArraysEqual(out, expected)
|
|
|
|
|
2023-04-17 15:27:03 -07:00
|
|
|
def testIndexUpdateScalarBug(self):
|
2024-09-20 07:51:48 -07:00
|
|
|
# https://github.com/jax-ml/jax/issues/14923
|
2023-04-17 15:27:03 -07:00
|
|
|
a = jnp.arange(10.)
|
|
|
|
out = a.at[0].apply(jnp.cos)
|
|
|
|
self.assertArraysEqual(out, a.at[0].set(1))
|
2022-02-18 09:44:40 -08:00
|
|
|
|
2022-10-06 15:53:18 +00:00
|
|
|
@jtu.sample_product(
  [
    {"name": group, "shape": spec.shape, "indexer": spec.indexer, "mode": mode}
    for mode in MODES
    # Out-of-bounds cases are added for every mode except "promise_in_bounds".
    for group, specs in STATIC_INDEXING_TESTS + (
      [] if mode == "promise_in_bounds"
      else STATIC_INDEXING_OUT_OF_BOUNDS_TESTS)
    for spec in specs
  ],
  dtype=float_dtypes,
)
def testStaticIndexingGrads(self, name, shape, dtype, indexer, mode):
  """Runs check_grads on the square of ``x.at[indexer].get(mode=mode)``.

  A tolerance of 1e-2 is used for 32-bit float dtypes; otherwise None is
  passed so that check_grads falls back to its own default.
  """
  sampler = jtu.rand_default(self.rng())
  tol = None
  if jnp.finfo(dtype).bits == 32:
    tol = 1e-2
  operand = sampler(shape, dtype)

  def squared_gather(x):
    # Use an arbitrary finite fill_value, since NaNs won't work in a numerical
    # gradient test.
    gathered = jnp.asarray(x).at[indexer].get(mode=mode, fill_value=7)
    return gathered**2

  check_grads(squared_gather, (operand,), 2, tol, tol, tol)
|
|
|
|
|
|
|
|
def _ReplaceSlicesWithTuples(self, idx):
  """Splits an indexer into slice-free data plus a function that rebuilds it.

  Args:
    idx: a slice, a tuple/list of indexers (handled recursively), or any
      other object.

  Returns:
    A pair ``(unpacked, pack)``. A slice becomes the tuple
    ``(start, stop, step)`` with every ``None`` field replaced by 0; a
    non-empty tuple or list becomes a tuple of its unpacked elements;
    anything else is returned unchanged. ``pack(unpacked)`` rebuilds an
    indexer of the original structure, putting ``None`` back into the slice
    fields that were originally ``None``.
  """
  if isinstance(idx, slice):
    fields = (idx.start, idx.stop, idx.step)
    # Positions of the slice fields that were left unspecified.
    missing = [pos for pos, field in enumerate(fields) if field is None]

    def rebuild_slice(values):
      return slice(*util.subvals(values, [(pos, None) for pos in missing]))

    return util.subvals(fields, [(pos, 0) for pos in missing]), rebuild_slice

  if isinstance(idx, (tuple, list)) and idx:
    container = type(idx)
    pieces, rebuilders = zip(
      *(self._ReplaceSlicesWithTuples(item) for item in idx))

    def rebuild_sequence(values):
      return container(
        rebuild(value) for rebuild, value in zip(rebuilders, values))

    return pieces, rebuild_sequence

  return idx, lambda x: x
|
|
|
|
|
2022-10-06 15:53:18 +00:00
|
|
|
@jtu.sample_product(
  [
    {"name": group, "shape": spec.shape, "indexer": spec.indexer}
    for group, specs in [
      ("OneSliceIndex", [
        IndexSpec(shape=(5,), indexer=slice(1, 3)),
        IndexSpec(shape=(5, 4), indexer=slice(1, 3)),
      ]),
      ("TwoSliceIndices", [
        IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))),
        IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2))),
      ]),
      ("NonUnitStrides", [
        IndexSpec(shape=(3,), indexer=slice(None, None, -1)),
        IndexSpec(shape=(3, 3), indexer=slice(0, 3, -2)),
        IndexSpec(shape=(3, 4, 5), indexer=slice(0, 4, 2)),
      ]),
      ("OnlyStartOrStopDynamic", [
        IndexSpec(shape=(5, 4), indexer=(slice(None, 3), slice(0, 2))),
        IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None))),
      ]),
    ]
    for spec in specs
  ],
  dtype=all_dtypes,
)
def testDynamicIndexingWithSlicesErrors(self, name, shape, dtype, indexer):
  """Expects IndexError when slice fields are passed as arguments to a jitted function.

  The slice fields are unpacked into plain tuples, passed as an argument to a
  ``jax.jit``-wrapped function, and repacked into slices inside it before
  indexing.
  """
  sampler = jtu.rand_default(self.rng())
  unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

  @jax.jit
  def fun(x, unpacked_indexer):
    return x[pack_indexer(unpacked_indexer)]

  with self.assertRaises(IndexError):
    fun(sampler(shape, dtype), unpacked_indexer)
|
|
|
|
|
2022-10-06 15:53:18 +00:00
|
|
|
@jtu.sample_product(
  [
    {"name": group, "shape": spec.shape, "indexer": spec.indexer}
    for group, specs in [
      ("OneIntIndex", [
        IndexSpec(shape=(3,), indexer=1),
        IndexSpec(shape=(3, 3), indexer=0),
        IndexSpec(shape=(3, 4, 5), indexer=2),
        IndexSpec(shape=(3,), indexer=-1),
        IndexSpec(shape=(3,), indexer=-2),
      ]),
      ("TwoIntIndices", [
        IndexSpec(shape=(3, 3), indexer=(2, 1)),
        IndexSpec(shape=(3, 4, 5), indexer=(1, 2)),
        IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)),
      ]),
      ("ThreeIntIndices", [
        IndexSpec((3, 4, 5), indexer=(1, 2, 3)),
      ]),
    ]
    for spec in specs
  ],
  dtype=all_dtypes,
)
def testDynamicIndexingWithIntegers(self, name, shape, dtype, indexer):
  """Compares integer indexing against NumPy with the indices passed as arguments.

  The indexer is handed to both implementations as a function argument
  (rather than closed over) and repacked inside the function before use.
  """
  sampler = jtu.rand_default(self.rng())
  unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

  def np_fun(x, unpacked_indexer):
    return np.asarray(x)[pack_indexer(unpacked_indexer)]

  def jnp_fun(x, unpacked_indexer):
    return jnp.array(x)[pack_indexer(unpacked_indexer)]

  def make_args():
    return [sampler(shape, dtype), unpacked_indexer]

  self._CheckAgainstNumpy(np_fun, jnp_fun, make_args)
  self._CompileAndCheck(jnp_fun, make_args)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2022-10-06 15:53:18 +00:00
|
|
|
@jtu.sample_product(
  [{"name": name, "shape": shape, "indexer": indexer}
   for name, index_specs in [
     ("OneIntIndex", [
       IndexSpec(shape=(3,), indexer=1),
       IndexSpec(shape=(3, 3), indexer=0),
       IndexSpec(shape=(3, 4, 5), indexer=2),
       IndexSpec(shape=(3,), indexer=-1),
       IndexSpec(shape=(3,), indexer=-2),
     ]),
     ("TwoIntIndices", [
       IndexSpec(shape=(3, 3), indexer=(2, 1)),
       IndexSpec(shape=(3, 4, 5), indexer=(1, 2)),
       IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)),
     ]),
     ("ThreeIntIndices", [
       IndexSpec((3, 4, 5), indexer=(1, 2, 3)),
     ]),
   ]
   for shape, indexer, _ in index_specs],
  dtype=float_dtypes,
)
def testDynamicIndexingWithIntegersGrads(self, name, shape, dtype, indexer):
  # Gradient check (up to second order) of jitted indexing where the index
  # is passed in as an argument and re-packed inside the jitted function.
  rng = jtu.rand_default(self.rng())
  if jnp.finfo(dtype).bits == 32:
    tol = 1e-2
  else:
    tol = None
  unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

  def index_operand(unpacked, operand):
    return operand[pack_indexer(unpacked)]

  jitted_index = jax.jit(index_operand)

  arr = rng(shape, dtype)
  check_grads(lambda operand: jitted_index(unpacked_indexer, operand),
              (arr,), 2, tol, tol, tol)
|
|
|
|
|
2022-10-06 15:53:18 +00:00
|
|
|
@jtu.sample_product(
  [{"name": name, "shape": shape, "indexer": indexer}
   for name, index_specs in ADVANCED_INDEXING_TESTS
   for shape, indexer, _ in index_specs],
  dtype=all_dtypes,
)
def testAdvancedIntegerIndexing(self, name, shape, dtype, indexer):
  # Compares jnp and NumPy indexing, with the indexer supplied as an argument.
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(shape, dtype), indexer]

  def np_fun(operand, idx):
    return np.asarray(operand)[idx]

  def jnp_fun(operand, idx):
    return jnp.asarray(operand)[idx]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2022-10-06 15:53:18 +00:00
|
|
|
@jtu.sample_product(dtype=jtu.dtypes.unsigned + jtu.dtypes.integer)
def testIndicesNormalizationByType(self, dtype):
  # Inspects the primitives emitted when indexing with an integer array of
  # the given dtype.
  operand = jnp.arange(10)
  indices = jnp.arange(3, dtype=dtype)
  jaxpr = jax.make_jaxpr(operand.__getitem__)(indices)
  primitives = [eqn.primitive for eqn in jaxpr.eqns]
  gather_tail = [lax.broadcast_in_dim_p, lax.gather_p]

  if np.issubdtype(dtype, np.unsignedinteger):
    # Unsigned integers should not require lt, add, and select.
    self.assertEqual(primitives, [lax.convert_element_type_p, *gather_tail])
    return

  # May or may not contain convert_element_type.
  self.assertIn(len(primitives), [5, 6])
  self.assertEqual(primitives[:3], [lax.lt_p, lax.add_p, lax.select_n_p])
  self.assertEqual(primitives[-2:], gather_tail)
|
|
|
|
|
2022-10-06 15:53:18 +00:00
|
|
|
@jtu.sample_product(
  [{"name": name, "shape": shape, "indexer": indexer}
   for name, index_specs in [
     ("One1DIntArrayIndex", [
       IndexSpec(shape=(3,), indexer=np.array([0, 1])),
       IndexSpec(shape=(3, 3), indexer=np.array([1, 2, 1])),
       IndexSpec(shape=(3, 4, 5), indexer=np.array([0, 2, 0, 1])),
       IndexSpec(shape=(3,), indexer=np.array([-1, 1])),
       IndexSpec(shape=(3,), indexer=np.array([-2, -1])),
     ]),
     ("One2DIntArrayIndex", [
       IndexSpec(shape=(3,), indexer=np.array([[0, 0]])),
       IndexSpec(shape=(3, 3), indexer=np.array([[1, 2, 1],
                                                 [0, 1, -1]])),
       IndexSpec(shape=(3, 4, 5), indexer=np.array([[0, 2, 0, 1],
                                                    [-1, -2, 1, 0]])),
     ]),
     ("Two1DIntArrayIndicesNoBroadcasting", [
       IndexSpec(shape=(3, 3), indexer=(np.array([0, 1]),
                                        np.array([1, 2]))),
       IndexSpec(shape=(3, 4, 5), indexer=(np.array([0, 2, 0, 1]),
                                           np.array([-1, 0, -1, 2]))),
     ]),
     ("Two1DIntArrayIndicesWithBroadcasting", [
       IndexSpec(shape=(3, 3), indexer=(np.array([[0, 1]]),
                                        np.array([1, 2]))),
       IndexSpec(shape=(3, 4, 5), indexer=(np.array([[0, 2, 0, 1]]),
                                           np.array([-1, 0, -1, 2]))),
     ]),
     ("TupleOfPythonIntsAndIntArrays", [
       IndexSpec(shape=(3, 4, 5), indexer=(0, np.array([0, 1]))),
       IndexSpec(shape=(3, 4, 5), indexer=(0, 1,
                                           np.array([[2, 3, 0, 3]]))),
     ]),
     ("TupleOfListsOfPythonIntsAndIntArrays", [
       IndexSpec(shape=(3, 4, 5), indexer=([0, 1], np.array([0]))),
       IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]],
                                           np.array([[2, 3, 0, 3]]))),
     ]),
   ]
   for shape, indexer, _ in index_specs],
  dtype=float_dtypes,
)
def testAdvancedIntegerIndexingGrads(self, name, shape, dtype, indexer):
  # Gradient check (up to second order) of indexing with a closed-over
  # advanced indexer.
  rng = jtu.rand_default(self.rng())
  if jnp.finfo(dtype).bits == 32:
    tol = 1e-2
  else:
    tol = None
  arg = rng(shape, dtype)

  def take(operand):
    return jnp.asarray(operand)[indexer]

  check_grads(take, (arg,), 2, tol, tol, eps=1.)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2022-10-06 15:53:18 +00:00
|
|
|
@jtu.sample_product(
  [{"name": name, "shape": shape, "indexer": indexer}
   for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS
   for shape, indexer, _ in index_specs],
  dtype=all_dtypes,
)
def testMixedAdvancedIntegerIndexing(self, name, shape, dtype, indexer):
  rng = jtu.rand_default(self.rng())

  # Split the indexer: ndarray entries stay in place and are passed as
  # arguments; every other entry is replaced by a `()` placeholder and
  # remembered with its position so it can be substituted back in.
  indexer_with_dummies = []
  substitutes = []
  for position, entry in enumerate(indexer):
    if isinstance(entry, np.ndarray):
      indexer_with_dummies.append(entry)
    else:
      indexer_with_dummies.append(())
      substitutes.append((position, entry))

  def args_maker():
    return [rng(shape, dtype), indexer_with_dummies]

  def rebuild(dummies):
    # Restores the original indexer container type with the static entries.
    return type(indexer)(util.subvals(dummies, substitutes))

  def jnp_fun(operand, dummies):
    return jnp.asarray(operand)[rebuild(dummies)]

  def np_fun(operand, dummies):
    return np.asarray(operand)[rebuild(dummies)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def testAdvancedIndexingManually(self):
  # Each indexing expression must give the same result when called directly
  # and when called through jax.jit.
  operand = self.rng().randn(3, 4, 5)
  index_array = np.array([0, 2, -1, 0])

  indexing_ops = (
      lambda x, index_array: x[..., index_array, :],
      lambda x, index_array: x[..., index_array, :, index_array, None],
      lambda x, index_array: x[index_array, ..., index_array[:, None], None],
  )
  for op in indexing_ops:
    compiled_op = jax.jit(op)
    direct = op(operand, index_array)
    compiled = compiled_op(operand, index_array)
    self.assertAllClose(direct, compiled)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def testUnpacking(self):
  # Tuple-unpacking an array must work both directly and under jit.
  def foo(triple):
    first, second, third = triple
    return first + second + third

  direct = foo(np.arange(3))
  compiled = jax.jit(foo)(np.arange(3))

  self.assertAllClose(direct, compiled)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-23 11:02:00 -08:00
|
|
|
def testBooleanIndexingArray1D(self):
  # A 1D boolean ndarray mask on a device array matches NumPy.
  mask = np.array([True, True, False])
  host_values = np.arange(3)
  actual = jax.device_put(host_values)[mask]
  self.assertAllClose(actual, np.arange(3)[mask], check_dtypes=False)
|
|
|
|
|
|
|
|
def testBooleanIndexingList1D(self):
  # A plain Python list of booleans is rejected with a TypeError.
  mask_list = [True, True, False]
  device_values = jax.device_put(np.arange(3))
  with self.assertRaisesRegex(TypeError, ARRAY_MSG):
    device_values[mask_list]
|
2018-12-23 11:02:00 -08:00
|
|
|
|
|
|
|
def testBooleanIndexingArray2DBroadcast(self):
  # A 1D boolean mask applied to the leading axis of a 2D array.
  host_values = np.arange(8).reshape(4, 2)
  row_mask = np.array([True, True, False, True])
  actual = jax.device_put(host_values)[row_mask]
  self.assertAllClose(actual, host_values[row_mask], check_dtypes=False)
|
|
|
|
|
|
|
|
def testBooleanIndexingList2DBroadcast(self):
  # A Python list of booleans is rejected for 2D operands as well.
  host_values = np.arange(8).reshape(4, 2)
  mask_list = [True, True, False, True]
  with self.assertRaisesRegex(TypeError, ARRAY_MSG):
    jax.device_put(host_values)[mask_list]
|
2018-12-23 11:02:00 -08:00
|
|
|
|
|
|
|
def testBooleanIndexingArray2D(self):
  # A boolean mask with the same 2D shape as the operand.
  host_values = np.arange(8).reshape(4, 2)
  mask = np.array([
      [True, False],
      [False, True],
      [False, False],
      [True, True],
  ])
  actual = jax.device_put(host_values)[mask]
  self.assertAllClose(actual, host_values[mask], check_dtypes=False)
|
|
|
|
|
2021-11-01 12:24:05 -07:00
|
|
|
def testBoolean1DIndexingWithEllipsis(self):
  # Regression test for https://github.com/jax-ml/jax/issues/8412
  host_values = np.arange(24).reshape(4, 3, 2)
  index = (..., np.array([True, False]))
  actual = jnp.array(host_values)[index]
  self.assertAllClose(actual, host_values[index], check_dtypes=False)
|
|
|
|
|
2021-12-28 09:52:54 -08:00
|
|
|
def testBoolean1DIndexingWithEllipsis2(self):
  # Regression test for https://github.com/jax-ml/jax/issues/9050
  host_values = np.arange(3)
  index = (..., np.array([True, False, True]))
  actual = jnp.array(host_values)[index]
  self.assertAllClose(actual, host_values[index], check_dtypes=False)
|
|
|
|
|
|
|
|
def testBoolean1DIndexingWithEllipsis3(self):
  # Integer index, then Ellipsis, then a 1D boolean mask on the last axis.
  host_values = np.arange(6).reshape(2, 3)
  index = (0, ..., np.array([True, False, True]))
  actual = jnp.array(host_values)[index]
  self.assertAllClose(actual, host_values[index], check_dtypes=False)
|
|
|
|
|
2021-11-01 12:24:05 -07:00
|
|
|
def testBoolean2DIndexingWithEllipsis(self):
  # Ellipsis followed by a 2D boolean mask covering the trailing two axes.
  host_values = np.arange(24).reshape(4, 3, 2)
  trailing_mask = np.array([[True, False], [True, False], [False, False]])
  index = (..., trailing_mask)
  actual = jnp.array(host_values)[index]
  self.assertAllClose(actual, host_values[index], check_dtypes=False)
|
|
|
|
|
|
|
|
def testBoolean1DIndexingWithTrailingEllipsis(self):
  # A 1D boolean mask on the leading axis followed by Ellipsis.
  host_values = np.arange(24).reshape(4, 3, 2)
  index = (np.array([True, False, True, False]), ...)
  actual = jnp.array(host_values)[index]
  self.assertAllClose(actual, host_values[index], check_dtypes=False)
|
|
|
|
|
2018-12-23 11:02:00 -08:00
|
|
|
def testBooleanIndexingDynamicShapeError(self):
|
2020-07-14 13:03:24 -07:00
|
|
|
x = np.zeros(3)
|
|
|
|
i = np.array([True, True, False])
|
2021-09-13 16:00:22 -04:00
|
|
|
self.assertRaises(IndexError, lambda: jax.jit(lambda x, i: x[i])(x, i))
|
2018-12-23 11:02:00 -08:00
|
|
|
|
2019-01-02 17:46:46 -08:00
|
|
|
def testIssue187(self):
  # Indexing with two Python lists of ints must not crash, and under jit it
  # must agree with NumPy.
  diagonal = [0, 2, 4]

  ones = jnp.ones((5, 5))
  ones[diagonal, diagonal]  # doesn't crash

  host_values = np.arange(25).reshape((5, 5))
  actual = jax.jit(lambda x: x[[0, 2, 4], [0, 2, 4]])(host_values)
  self.assertAllClose(actual, host_values[diagonal, diagonal],
                      check_dtypes=False)
|
|
|
|
|
2019-07-23 14:18:07 -04:00
|
|
|
def testJVPOfGradOfIndexing(self):
  # Should return a value, even though we didn't pass a symbolic zero as the
  # index tangent.
  operand = jnp.ones((3, 4), jnp.float32)
  indices = jnp.ones((3,), jnp.int32)

  def summed_take(x, i):
    return jnp.sum(x[i])

  index_tangent = np.zeros(indices.shape, dtypes.float0)
  primals, tangents = jax.jvp(jax.grad(summed_take), (operand, indices),
                              (operand, index_tangent))

  per_row = np.array([0, 3, 0], dtype=np.float32)
  expected = np.broadcast_to(per_row[:, None], (3, 4))
  self.assertAllClose(expected, primals)
  self.assertAllClose(np.zeros_like(operand), tangents)
|
2019-07-23 14:18:07 -04:00
|
|
|
|
2023-04-10 12:02:59 -07:00
|
|
|
def testSimpleIndexingUsesSlice(self):
  def traced_eqns(fn, *args):
    # Equations of the jaxpr obtained by tracing `fn` on `args`.
    return jax.make_jaxpr(fn)(*args).jaxpr.eqns

  # Static slices produce a single slice.
  eqns = traced_eqns(lambda x: x[:2, :2], jnp.ones((3, 4)))
  self.assertEqual(len(eqns), 1)
  self.assertEqual(eqns[-1].primitive, lax.slice_p)

  # Static integer indices produce a slice followed by a squeeze.
  static_int_cases = (
      lambda x: x[0, :2, 1],
      lambda x: x[0, 0],
      lambda x: x[:, 1],
  )
  for index_fn in static_int_cases:
    eqns = traced_eqns(index_fn, jnp.ones((3, 4, 5)))
    self.assertEqual(len(eqns), 2)
    self.assertEqual(eqns[-2].primitive, lax.slice_p)
    self.assertEqual(eqns[-1].primitive, lax.squeeze_p)

  # Indexing with `Ellipsis` is not lowered to `gather`.
  eqns = traced_eqns(lambda x: x[..., 0], jnp.ones((3, 4, 5)))
  self.assertLen(eqns, 2)
  self.assertEqual(eqns[-2].primitive, lax.slice_p)
  self.assertEqual(eqns[-1].primitive, lax.squeeze_p)

  # Simple reverses lower to lax.rev_p
  eqns = traced_eqns(lambda x: x[:, ::-1], jnp.ones((3, 4)))
  self.assertEqual(len(eqns), 1)
  self.assertEqual(eqns[0].primitive, lax.rev_p)

  # Non-static indices produce a dynamic slice
  eqns = traced_eqns(lambda x, i: x[i], jnp.ones((4,)), 2)
  self.assertEqual(len(eqns), 6)
  self.assertEqual(eqns[-2].primitive, lax.dynamic_slice_p)
  self.assertEqual(eqns[-1].primitive, lax.squeeze_p)
|
|
|
|
|
2019-11-01 13:46:13 -07:00
|
|
|
def testTrivialGatherIsntGenerated(self):
|
2024-09-20 07:51:48 -07:00
|
|
|
# https://github.com/jax-ml/jax/issues/1621
|
2021-09-13 16:00:22 -04:00
|
|
|
jaxpr = jax.make_jaxpr(lambda x: x[:, None])(np.arange(4))
|
2019-11-26 08:18:53 +01:00
|
|
|
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
|
2019-11-01 13:46:13 -07:00
|
|
|
self.assertNotIn('gather', str(jaxpr))
|
2019-07-23 14:18:07 -04:00
|
|
|
|
2022-04-10 00:18:10 +01:00
|
|
|
jaxpr = jax.make_jaxpr(lambda x: x[0:6:1])(np.arange(4))
|
2022-08-25 15:14:01 -07:00
|
|
|
self.assertEqual(len(jaxpr.jaxpr.eqns), 0)
|
2023-04-10 12:02:59 -07:00
|
|
|
|
2022-04-10 00:18:10 +01:00
|
|
|
jaxpr = jax.make_jaxpr(lambda x: x[:4])(np.arange(4))
|
2022-08-25 15:14:01 -07:00
|
|
|
self.assertEqual(len(jaxpr.jaxpr.eqns), 0)
|
2022-04-10 00:18:10 +01:00
|
|
|
|
2022-05-09 23:58:42 +01:00
|
|
|
jaxpr = jax.make_jaxpr(lambda x: x[::-1])(np.arange(4))
|
|
|
|
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
|
|
|
|
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.rev_p)
|
|
|
|
|
2023-04-10 12:02:59 -07:00
|
|
|
def testOOBEmptySlice(self):
  # Slices whose bounds select nothing give empty float32 results.
  vector = jnp.arange(4, dtype='float32')
  for empty_slice in (slice(1, 0), slice(-2, -10), slice(5, 10)):
    self.assertArraysEqual(vector[empty_slice],
                           jnp.empty(0, dtype='float32'))

  matrix = jnp.arange(6, dtype='float32').reshape(2, 3)
  self.assertArraysEqual(matrix[-1:-4], jnp.empty((0, 3), dtype='float32'))
  self.assertArraysEqual(matrix[:, 3:2], jnp.empty((2, 0), dtype='float32'))
|
|
|
|
|
2020-04-11 17:27:14 +02:00
|
|
|
def testIndexingEmptyDimension(self):
  # Issue 2671: XLA error when indexing into dimension of size 0
  empty_cols = jnp.ones((2, 0))
  numpy_oob = "index .* is out of bounds for axis .* with size 0"
  jax_oob = "index is out of bounds for axis .* with size 0"

  # The following work, even on axis 1 of size 0
  with jax.numpy_rank_promotion('allow'):
    _ = (empty_cols[0, :] + empty_cols[0, None] + empty_cols[0, 1:]
         + empty_cols[0, 1:3:2])

  with self.assertRaisesRegex(IndexError, numpy_oob):
    _ = np.ones((2, 0))[0, 0]  # The numpy error
  with self.assertRaisesRegex(IndexError, jax_oob):
    _ = empty_cols[0, 0]  # JAX indexing
  with self.assertRaisesRegex(IndexError, jax_oob):
    jax.jit(lambda i: empty_cols[0, i])(0)  # JAX indexing under jit
|
2020-04-11 17:27:14 +02:00
|
|
|
|
2019-11-07 10:14:16 -08:00
|
|
|
def testBooleanIndexingWithEmptyResult(self):
  # based on a TensorFlow Probability test that started failing after #1622
  values = jnp.array([-1])
  all_false = jnp.array([False])
  actual = values[all_false]  # doesn't crash

  host_expected = np.array([-1])[np.array([False])]
  self.assertAllClose(actual, host_expected, check_dtypes=False)
|
|
|
|
|
2021-08-03 09:51:52 -07:00
|
|
|
def testBooleanIndexingShapeMismatch(self):
|
2024-09-20 07:51:48 -07:00
|
|
|
# Regression test for https://github.com/jax-ml/jax/issues/7329
|
2021-08-03 09:51:52 -07:00
|
|
|
x = jnp.arange(4)
|
|
|
|
idx = jnp.array([True, False])
|
|
|
|
with self.assertRaisesRegex(IndexError, "boolean index did not match shape.*"):
|
|
|
|
x[idx]
|
|
|
|
|
2023-11-15 09:03:15 -08:00
|
|
|
def testBooleanIndexingWithNone(self):
  # Regression test for https://github.com/jax-ml/jax/issues/18542
  values = jnp.arange(6).reshape(2, 3)
  row_mask = jnp.array([True, False])
  actual = values[(None, row_mask)]
  self.assertAllClose(actual, jnp.arange(3).reshape(1, 1, 3))
|
|
|
|
|
|
|
|
def testBooleanIndexingWithNoneAndEllipsis(self):
  # Regression test for https://github.com/jax-ml/jax/issues/18542
  values = jnp.arange(6).reshape(2, 3)
  column_mask = jnp.array([True, False, False])
  actual = values[None, ..., column_mask]
  self.assertAllClose(actual, jnp.array([0, 3]).reshape(1, 2, 1))
|
|
|
|
|
|
|
|
def testBooleanIndexingWithEllipsisAndNone(self):
  # Regression test for https://github.com/jax-ml/jax/issues/18542
  values = jnp.arange(6).reshape(2, 3)
  column_mask = jnp.array([True, False, False])
  actual = values[..., None, column_mask]
  self.assertAllClose(actual, jnp.array([0, 3]).reshape(2, 1, 1))
|
|
|
|
|
2021-08-03 09:51:52 -07:00
|
|
|
def testNontrivialBooleanIndexing(self):
  # Test nontrivial corner case in boolean indexing shape validation
  rng = jtu.rand_default(self.rng())
  leading_mask = rng((2, 3), np.bool_)
  trailing_mask = rng((6,), np.bool_)
  index = (leading_mask, trailing_mask)

  def args_maker():
    return [rng((2, 3, 6), np.int32)]

  def np_fun(operand):
    return np.asarray(operand)[index]

  def jnp_fun(operand):
    return jnp.asarray(operand)[index]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2024-08-06 09:56:03 -07:00
|
|
|
@parameterized.parameters(
  [(3,), (0,)],
  [(3, 4), (0,)],
  [(3, 4), (0, 4)],
  [(3, 4), (3, 0)],
  [(3, 4, 5), (3, 0)],
)
def testEmptyBooleanIndexing(self, x_shape, m_shape):
  # Regression test for https://github.com/jax-ml/jax/issues/22886
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(x_shape, np.int32), np.empty(m_shape, dtype=bool)]

  def np_fun(operand, mask):
    return np.asarray(operand)[np.asarray(mask)]

  def jnp_fun(operand, mask):
    return jnp.asarray(operand)[jnp.asarray(mask)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
|
|
|
|
2024-02-08 14:11:15 -08:00
|
|
|
@jtu.sample_product(
  shape=[(2, 3, 4, 5)],
  idx=[
    np.index_exp[True],
    np.index_exp[False],
    np.index_exp[..., True],
    np.index_exp[..., False],
    np.index_exp[0, :2, True],
    np.index_exp[0, :2, False],
    np.index_exp[:2, 0, True],
    np.index_exp[:2, 0, False],
    np.index_exp[:2, np.array([0, 2]), True],
    np.index_exp[np.array([1, 0]), :, True],
    np.index_exp[True, :, True, :, np.array(True)],
  ],
)
def testScalarBooleanIndexing(self, shape, idx):
  # Indexers containing scalar booleans must match NumPy.
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(shape, np.int32)]

  def np_fun(operand):
    return np.asarray(operand)[idx]

  def jnp_fun(operand):
    return jnp.asarray(operand)[idx]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
|
|
|
|
2024-05-20 05:33:36 -07:00
|
|
|
@jtu.sample_product(
  shape=[(2, 3, 4, 5)],
  update_ndim=[0, 1, 2],
  idx=[
    np.index_exp[True],
    np.index_exp[False],
    np.index_exp[..., True],
    np.index_exp[..., False],
    np.index_exp[0, :2, True],
    np.index_exp[0, :2, False],
    np.index_exp[:2, 0, True],
    np.index_exp[:2, 0, False],
    np.index_exp[:2, np.array([0, 2]), True],
    np.index_exp[np.array([1, 0]), :, True],
    np.index_exp[True, :, True, :, np.array(True)],
  ]
)
def testScalarBoolUpdate(self, shape, idx, update_ndim):
  # Compares `x.at[idx].set(update)` with NumPy's in-place `x[idx] = update`
  # for indexers containing scalar booleans. The update takes the trailing
  # `update_ndim` dimensions of the indexed region and is broadcast by the
  # assignment.
  indexed_shape = np.zeros(shape)[idx].shape
  # Slice from an explicit non-negative start: `indexed_shape[-0:]` would be
  # the whole shape, so update_ndim == 0 would never produce a 0-d update.
  start = max(len(indexed_shape) - update_ndim, 0)
  update_shape = indexed_shape[start:]
  rng = jtu.rand_default(self.rng())
  args_maker = lambda: [rng(shape, np.int32), rng(update_shape, np.int32)]

  def np_fun(x, update):
    # Copy first so the in-place assignment does not mutate the input.
    x = np.array(x, copy=True)
    x[idx] = update
    return x

  jnp_fun = lambda x, update: jnp.asarray(x).at[idx].set(update)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
|
|
|
|
2019-11-18 21:04:27 -05:00
|
|
|
def testFloatIndexingError(self):
|
2020-03-25 11:07:50 +02:00
|
|
|
BAD_INDEX_TYPE_ERROR = "Indexer must have integer or boolean type, got indexer with type"
|
|
|
|
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
|
|
|
|
jnp.zeros(2)[0.]
|
|
|
|
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
|
|
|
|
jnp.zeros((2, 2))[(0, 0.)]
|
|
|
|
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
|
|
|
|
jnp.zeros((2, 2))[(0, 0.)]
|
|
|
|
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
|
2021-09-13 16:00:22 -04:00
|
|
|
jax.jit(lambda idx: jnp.zeros((2, 2))[idx])((0, 0.))
|
2020-03-25 11:07:50 +02:00
|
|
|
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
|
[JAX] Replace uses of deprecated `jax.ops.index_update(x, idx, y)` APIs with their up-to-date, more succinct equivalent `x.at[idx].set(y)`.
The JAX operators:
jax.ops.index_update(x, jax.ops.index[idx], y)
jax.ops.index_add(x, jax.ops.index[idx], y)
...
have long been deprecated in lieu of their more succinct counterparts:
x.at[idx].set(y)
x.at[idx].add(y)
...
This change updates users of the deprecated APIs to use the current APIs, in preparation for removing the deprecated forms from JAX.
The main subtlety is that if `x` is not a JAX array, we must cast it to one using `jnp.asarray(x)` before using the new form, since `.at[...]` is only defined on JAX arrays.
PiperOrigin-RevId: 400209692
2021-10-01 08:39:11 -07:00
|
|
|
jnp.zeros(2).at[0.].add(1.)
|
2020-03-25 11:07:50 +02:00
|
|
|
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
|
[JAX] Replace uses of deprecated `jax.ops.index_update(x, idx, y)` APIs with their up-to-date, more succinct equivalent `x.at[idx].set(y)`.
The JAX operators:
jax.ops.index_update(x, jax.ops.index[idx], y)
jax.ops.index_add(x, jax.ops.index[idx], y)
...
have long been deprecated in lieu of their more succinct counterparts:
x.at[idx].set(y)
x.at[idx].add(y)
...
This change updates users of the deprecated APIs to use the current APIs, in preparation for removing the deprecated forms from JAX.
The main subtlety is that if `x` is not a JAX array, we must cast it to one using `jnp.asarray(x)` before using the new form, since `.at[...]` is only defined on JAX arrays.
PiperOrigin-RevId: 400209692
2021-10-01 08:39:11 -07:00
|
|
|
jnp.zeros(2).at[0.].set(1.)
|
2025-02-28 15:04:07 -08:00
|
|
|
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
|
|
|
|
jnp.zeros((2, 2))[jnp.arange(2), 1.0]
|
|
|
|
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
|
|
|
|
jnp.zeros((2, 2))[jnp.arange(2), 1 + 1j]
|
2020-03-25 11:07:50 +02:00
|
|
|
|
2023-03-20 08:55:16 -07:00
|
|
|
def testStrIndexingError(self):
|
|
|
|
msg = "JAX does not support string indexing"
|
|
|
|
with self.assertRaisesRegex(TypeError, msg):
|
|
|
|
jnp.zeros(2)['abc']
|
|
|
|
with self.assertRaisesRegex(TypeError, msg):
|
2023-11-21 15:52:35 -08:00
|
|
|
jnp.zeros((2, 3))[:, 'abc']
|
2023-03-20 08:55:16 -07:00
|
|
|
|
2024-09-20 07:51:48 -07:00
|
|
|
def testIndexOutOfBounds(self):  # https://github.com/jax-ml/jax/issues/2245
  values = jnp.arange(5, dtype=jnp.int32) + 1
  self.assertAllClose(values, values[:10])

  idx = jnp.array([-10, -6, -5, -4, 0, 3, 4, 5, 6, 100])

  # mode="clip": the expected values repeat the first / last element for
  # out-of-range indices.
  clipped = values.at[idx].get(mode="clip")
  self.assertArraysEqual(
      clipped, jnp.array([1, 1, 1, 2, 1, 4, 5, 5, 5, 5], jnp.int32))

  # mode="fill" without fill_value: the expected filler is NaN for float32,
  # the minimum for int32 and the maximum for uint32.
  nan = np.nan
  float_filled = values.astype(jnp.float32).at[idx].get(mode="fill")
  self.assertArraysEqual(
      float_filled,
      jnp.array([nan, nan, 1, 2, 1, 4, 5, nan, nan, nan], jnp.float32))

  imin = np.iinfo(np.int32).min
  int_filled = values.at[idx].get(mode="fill")
  self.assertArraysEqual(
      int_filled,
      jnp.array([imin, imin, 1, 2, 1, 4, 5, imin, imin, imin], jnp.int32))

  umax = np.iinfo(np.uint32).max
  uint_filled = values.astype(np.uint32).at[idx].get(mode="fill")
  self.assertArraysEqual(
      uint_filled,
      jnp.array([umax, umax, 1, 2, 1, 4, 5, umax, umax, umax], jnp.uint32))

  # An explicit fill_value is used as the filler instead.
  custom_filled = values.at[idx].get(mode="fill", fill_value=7)
  self.assertArraysEqual(
      custom_filled, jnp.array([7, 7, 1, 2, 1, 4, 5, 7, 7, 7], jnp.int32))
|
2020-02-17 18:25:07 +01:00
|
|
|
|
2021-11-30 15:43:06 -08:00
|
|
|
def testIndexingWeakTypes(self):
|
2022-06-09 15:21:49 -07:00
|
|
|
x = lax_internal._convert_element_type(jnp.arange(5), float, weak_type=True)
|
2021-11-30 15:43:06 -08:00
|
|
|
|
|
|
|
a = x.at[0].set(1.0)
|
|
|
|
self.assertEqual(a.dtype, x.dtype)
|
|
|
|
self.assertTrue(dtypes.is_weakly_typed(a))
|
|
|
|
|
|
|
|
b = x.at[0].add(1.0)
|
|
|
|
self.assertEqual(b.dtype, x.dtype)
|
|
|
|
self.assertTrue(dtypes.is_weakly_typed(b))
|
|
|
|
|
|
|
|
c = x.at[0].mul(1.0)
|
|
|
|
self.assertEqual(c.dtype, x.dtype)
|
|
|
|
self.assertTrue(dtypes.is_weakly_typed(c))
|
|
|
|
|
2022-06-09 15:21:49 -07:00
|
|
|
def testIndexingTypePromotion(self):
  def _check(x_type, y_type):
    # Setting a `y_type` value into an `x_type` array keeps the array dtype.
    target = jnp.arange(5, dtype=x_type)
    result = target.at[0].set(y_type(0))
    self.assertEqual(target.dtype, result.dtype)

  @jtu.ignore_warning(category=NumpyComplexWarning,
                      message="Casting complex values to real")
  def _check_warns(x_type, y_type, msg):
    with self.assertWarnsRegex(FutureWarning, msg):
      _check(x_type, y_type)

  def _check_raises(x_type, y_type, msg):
    with self.assertRaisesRegex(ValueError, msg):
      _check(x_type, y_type)

  # (array dtype, update dtype) pairs where the update is the narrower type.
  update_fits = [
      (jnp.int32, jnp.int16),
      (jnp.float32, jnp.float16),
      (jnp.float32, jnp.int32),
      (jnp.complex64, jnp.int32),
      (jnp.complex64, jnp.float32),
  ]
  # Pairs where the update is the wider type.
  update_too_wide = [
      (jnp.int16, jnp.int32),
      (jnp.int32, jnp.float32),
      (jnp.int32, jnp.complex64),
      (jnp.float16, jnp.float32),
      (jnp.float32, jnp.complex64),
  ]

  # Matching dtypes are always OK
  for same in (jnp.int32, jnp.float32, jnp.complex64):
    _check(same, same)

  # Weakly-typed y values promote.
  weak_pairs = [
      (jnp.int32, int),
      (jnp.float32, int),
      (jnp.float32, float),
      (jnp.complex64, int),
      (jnp.complex64, float),
      (jnp.complex64, complex),
  ]
  for x_type, y_type in weak_pairs:
    _check(x_type, y_type)

  # in standard promotion mode, strong types can promote.
  msg = "scatter inputs have incompatible types"
  with jax.numpy_dtype_promotion('standard'):
    for x_type, y_type in update_fits:
      _check(x_type, y_type)

    # TODO(jakevdp): make these _check_raises
    for x_type, y_type in update_too_wide:
      _check_warns(x_type, y_type, msg)

  # in strict promotion mode, strong types do not promote.
  msg = "Input dtypes .* have no available implicit dtype promotion path"
  with jax.numpy_dtype_promotion('strict'):
    for x_type, y_type in update_fits + update_too_wide:
      _check_raises(x_type, y_type, msg)
|
|
|
|
|
2023-11-21 15:52:35 -08:00
|
|
|
def testWrongNumberOfIndices(self):
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
IndexError,
|
2024-07-31 13:57:48 -07:00
|
|
|
"Too many indices: 0-dimensional array indexed with 1 regular index."):
|
|
|
|
jnp.array(1)[0]
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
IndexError,
|
|
|
|
"Too many indices: 1-dimensional array indexed with 2 regular indices."):
|
2023-11-21 15:52:35 -08:00
|
|
|
jnp.zeros(3)[:, 5]
|
|
|
|
|
2020-02-17 18:25:07 +01:00
|
|
|
|
2019-02-22 07:55:36 -05:00
|
|
|
def _broadcastable_shapes(shape):
|
|
|
|
"""Returns all shapes that broadcast to `shape`."""
|
|
|
|
def f(rshape):
|
|
|
|
yield []
|
|
|
|
if rshape:
|
2019-03-04 15:49:03 -05:00
|
|
|
for s in f(rshape[1:]):
|
|
|
|
yield rshape[0:1] + s
|
2019-02-22 07:55:36 -05:00
|
|
|
if rshape[0] != 1:
|
2019-03-04 15:49:03 -05:00
|
|
|
for s in f(rshape[1:]):
|
|
|
|
yield [1] + s
|
|
|
|
for x in f(list(reversed(shape))):
|
|
|
|
yield list(reversed(x))
|
2019-02-22 07:55:36 -05:00
|
|
|
|
2020-11-18 16:13:34 -08:00
|
|
|
|
2022-06-09 15:21:49 -07:00
|
|
|
# TODO(jakevdp): move this implementation to jax.dtypes & use in scatter?
|
|
|
|
def _can_cast(from_, to):
|
2022-06-16 14:00:11 -07:00
|
|
|
with jax.numpy_dtype_promotion('standard'):
|
|
|
|
return lax.dtype(to) == dtypes.result_type(from_, to)
|
2022-06-09 15:21:49 -07:00
|
|
|
|
|
|
|
|
|
|
|
def _compatible_dtypes(op, dtype, inexact=False):
  """Returns the update dtypes to test against an operand of `dtype`.

  For ADD and SUB only `dtype` itself is returned. For every other op the
  result is each dtype in `float_dtypes` (when `inexact` is truthy) or in
  `all_dtypes` (otherwise) for which `_can_cast(candidate, dtype)` holds.
  """
  if op in (UpdateOps.ADD, UpdateOps.SUB):
    return [dtype]
  candidates = float_dtypes if inexact else all_dtypes
  return [candidate for candidate in candidates if _can_cast(candidate, dtype)]
|
|
|
|
|
|
|
|
|
2019-02-22 07:55:36 -05:00
|
|
|
class UpdateOps(enum.Enum):
  """Indexed-update operations exercised by the `.at[...]` tests.

  The helpers defined below are plain functions living in the class namespace;
  they are called through the class with the operation passed explicitly,
  e.g. `UpdateOps.np_fn(op, indexer, x, y)`.
  """
  UPDATE = 0
  ADD = 1
  SUB = 2
  MUL = 3
  DIV = 4
  POW = 5
  MIN = 6
  MAX = 7

  def np_fn(op, indexer, x, y):
    """NumPy reference: returns a copy of `x` with `op` applied at `indexer`."""
    result = x.copy()
    # Each entry is a thunk so that only the requested operation is computed.
    # DIV and POW cast `y` to the operand dtype and silence RuntimeWarnings.
    reference_impls = {
        UpdateOps.UPDATE: lambda: y,
        UpdateOps.ADD: lambda: result[indexer] + y,
        UpdateOps.SUB: lambda: result[indexer] - y,
        UpdateOps.MUL: lambda: result[indexer] * y,
        UpdateOps.DIV: jtu.ignore_warning(category=RuntimeWarning)(
            lambda: result[indexer] / y.astype(result.dtype)),
        UpdateOps.POW: jtu.ignore_warning(category=RuntimeWarning)(
            lambda: result[indexer] ** y.astype(result.dtype)),
        UpdateOps.MIN: lambda: np.minimum(result[indexer], y),
        UpdateOps.MAX: lambda: np.maximum(result[indexer], y),
    }
    compute = reference_impls[op]
    result[indexer] = compute()
    return result

  def jax_fn(op, indexer, x, y, indices_are_sorted=False,
             unique_indices=False, mode=None):
    """Applies `op` through the matching `x.at[indexer].<method>` call."""
    method_names = {
        UpdateOps.UPDATE: "set",
        UpdateOps.ADD: "add",
        UpdateOps.SUB: "subtract",
        UpdateOps.MUL: "multiply",
        UpdateOps.DIV: "divide",
        UpdateOps.POW: "power",
        UpdateOps.MIN: "min",
        UpdateOps.MAX: "max",
    }
    ref = jnp.array(x).at[indexer]
    apply_update = getattr(ref, method_names[op])
    return apply_update(y, indices_are_sorted=indices_are_sorted,
                        unique_indices=unique_indices, mode=mode)

  def dtypes(op):
    """Returns the operand dtypes over which `op` is tested."""
    if op == UpdateOps.UPDATE:
      return all_dtypes
    if op in (UpdateOps.DIV, UpdateOps.POW):
      return jtu.dtypes.inexact
    return default_dtypes
|
2019-02-22 07:55:36 -05:00
|
|
|
|
2021-09-22 14:13:13 -04:00
|
|
|
def _update_tol(op):
  """Returns the per-dtype tolerance dict used when comparing `op` results.

  Every op gets a complex128 tolerance of 1e-14. POW additionally gets a
  float32/complex64 tolerance, which is 2e-4 when the test device matches
  "tpu" and 1e-5 otherwise.
  """
  if op != UpdateOps.POW:
    return {np.complex128: 1e-14}
  on_tpu = jtu.test_device_matches(["tpu"])
  f32_tol = 2e-4 if on_tpu else 1e-5
  return {np.float32: f32_tol, np.complex64: f32_tol, np.complex128: 1e-14}
|
|
|
|
|
2021-08-11 17:32:36 -04:00
|
|
|
|
2019-02-22 07:55:36 -05:00
|
|
|
class IndexedUpdateTest(jtu.JaxTestCase):
|
|
|
|
|
2022-10-06 15:53:18 +00:00
|
|
|
@jtu.sample_product(
  [
    {"name": test_name, "shape": operand_shape, "indexer": index,
     "update_shape": upd_shape}
    for test_name, specs in STATIC_INDEXING_TESTS
    for operand_shape, index, result_shape in specs
    for upd_shape in _broadcastable_shapes(result_shape)
  ],
  [
    {"op": update_op, "dtype": operand_dtype, "update_dtype": upd_dtype}
    for update_op in UpdateOps
    for operand_dtype in UpdateOps.dtypes(update_op)
    for upd_dtype in _compatible_dtypes(update_op, operand_dtype)
  ],
  mode=MODES,
)
def testStaticIndexing(self, name, shape, dtype, update_shape, update_dtype,
                       indexer, op, mode):
  # Checks `x.at[indexer].<op>(y, mode=mode)` against the NumPy reference for
  # every static indexer, every update shape that broadcasts to the indexed
  # shape, and every compatible update dtype.
  sampler = jtu.rand_default(self.rng())

  def args_maker():
    return [sampler(shape, dtype), sampler(update_shape, update_dtype)]

  def np_fn(x, y):
    return UpdateOps.np_fn(op, indexer, x, y)

  def jax_fn(x, y):
    return UpdateOps.jax_fn(op, indexer, x, y, mode=mode)

  with jtu.strict_promotion_if_dtypes_match([dtype, update_dtype]):
    self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol=_update_tol(op))
    self._CompileAndCheck(jax_fn, args_maker)
|
2019-02-22 07:55:36 -05:00
|
|
|
|
2022-10-06 15:53:18 +00:00
|
|
|
@jtu.sample_product(
  [
    {"name": test_name, "shape": operand_shape, "indexer": index,
     "update_shape": upd_shape}
    for test_name, specs in ADVANCED_INDEXING_TESTS_NO_REPEATS
    for operand_shape, index, result_shape in specs
    for upd_shape in _broadcastable_shapes(result_shape)
  ],
  [
    {"op": update_op, "dtype": operand_dtype, "update_dtype": upd_dtype}
    for update_op in UpdateOps
    for operand_dtype in UpdateOps.dtypes(update_op)
    for upd_dtype in _compatible_dtypes(update_op, operand_dtype)
  ],
)
def testAdvancedIndexing(self, name, shape, dtype, update_shape, update_dtype,
                         indexer, op):
  # Checks indexed updates with advanced indexers from the NO_REPEATS list
  # against the NumPy reference; the JAX call passes unique_indices=True.
  sampler = jtu.rand_default(self.rng())

  def args_maker():
    return [sampler(shape, dtype), sampler(update_shape, update_dtype)]

  def np_fn(x, y):
    return UpdateOps.np_fn(op, indexer, x, y)

  def jax_fn(x, y):
    return UpdateOps.jax_fn(op, indexer, x, y, unique_indices=True)

  with jtu.strict_promotion_if_dtypes_match([dtype, update_dtype]):
    self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol=_update_tol(op))
    self._CompileAndCheck(jax_fn, args_maker)
|
2019-02-22 07:55:36 -05:00
|
|
|
|
2022-10-06 15:53:18 +00:00
|
|
|
@jtu.sample_product(
  [
    {"name": test_name, "shape": operand_shape, "indexer": index,
     "update_shape": upd_shape}
    for test_name, specs in ADVANCED_INDEXING_TESTS_NO_REPEATS_SORTED
    for operand_shape, index, result_shape in specs
    for upd_shape in _broadcastable_shapes(result_shape)
  ],
  [
    {"op": update_op, "dtype": operand_dtype, "update_dtype": upd_dtype}
    for update_op in UpdateOps
    for operand_dtype in UpdateOps.dtypes(update_op)
    for upd_dtype in _compatible_dtypes(update_op, operand_dtype)
  ],
)
def testAdvancedIndexingSorted(self, name, shape, dtype, update_shape,
                               update_dtype, indexer, op):
  # Same comparison as the unsorted advanced-indexing test, but over the
  # NO_REPEATS_SORTED indexers, with both indices_are_sorted=True and
  # unique_indices=True passed to JAX and dtypes checked explicitly.
  sampler = jtu.rand_default(self.rng())

  def args_maker():
    return [sampler(shape, dtype), sampler(update_shape, update_dtype)]

  def np_fn(x, y):
    return UpdateOps.np_fn(op, indexer, x, y)

  def jax_fn(x, y):
    return UpdateOps.jax_fn(
        op, indexer, x, y, indices_are_sorted=True, unique_indices=True)

  with jtu.strict_promotion_if_dtypes_match([dtype, update_dtype]):
    self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, check_dtypes=True,
                            tol=_update_tol(op))
    self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)
|
2020-07-21 23:16:27 -07:00
|
|
|
|
2022-10-06 15:53:18 +00:00
|
|
|
@jtu.sample_product(
  [
    {"name": test_name, "shape": operand_shape, "indexer": index,
     "update_shape": upd_shape}
    for test_name, specs in MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS
    for operand_shape, index, result_shape in specs
    for upd_shape in _broadcastable_shapes(result_shape)
  ],
  [
    {"op": update_op, "dtype": operand_dtype, "update_dtype": upd_dtype}
    for update_op in UpdateOps
    for operand_dtype in UpdateOps.dtypes(update_op)
    for upd_dtype in _compatible_dtypes(update_op, operand_dtype)
  ],
)
def testMixedAdvancedIndexing(self, name, shape, dtype, update_shape,
                              update_dtype, indexer, op):
  # Checks indexed updates over the mixed advanced-indexing (NO_REPEATS)
  # indexers against the NumPy reference, using default update options.
  sampler = jtu.rand_default(self.rng())

  def args_maker():
    return [sampler(shape, dtype), sampler(update_shape, update_dtype)]

  def np_fn(x, y):
    return UpdateOps.np_fn(op, indexer, x, y)

  def jax_fn(x, y):
    return UpdateOps.jax_fn(op, indexer, x, y)

  with jtu.strict_promotion_if_dtypes_match([dtype, update_dtype]):
    self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol=_update_tol(op))
    self._CompileAndCheck(jax_fn, args_maker)
|
2019-07-14 10:57:41 -04:00
|
|
|
|
2022-10-06 15:53:18 +00:00
|
|
|
@jtu.sample_product(
  [
    {"name": test_name, "mode": scatter_mode, "shape": operand_shape,
     "indexer": index, "update_shape": upd_shape}
    for scatter_mode in [None] + MODES
    # Out-of-bounds indexers are left out for "promise_in_bounds".
    for test_name, specs in (
      STATIC_INDEXING_TESTS if scatter_mode == "promise_in_bounds" else
      STATIC_INDEXING_TESTS + STATIC_INDEXING_OUT_OF_BOUNDS_TESTS)
    for operand_shape, index, result_shape in specs
    for upd_shape in _broadcastable_shapes(result_shape)
  ],
  [
    {"op": update_op, "dtype": operand_dtype, "update_dtype": upd_dtype}
    for update_op in [UpdateOps.ADD, UpdateOps.SUB, UpdateOps.MUL,
                      UpdateOps.UPDATE]
    for operand_dtype in float_dtypes
    for upd_dtype in _compatible_dtypes(update_op, operand_dtype, inexact=True)
  ],
)
def testStaticIndexingGrads(self, name, shape, dtype, update_shape,
                            update_dtype, indexer, op, mode):
  # Checks gradients up to order 2 of the indexed update with respect to both
  # the operand and the update values.
  sampler = jtu.rand_default(self.rng())

  def jax_fn(x, y):
    return UpdateOps.jax_fn(op, indexer, x, y, mode=mode, unique_indices=True)

  operand = sampler(shape, dtype)
  updates = sampler(update_shape, update_dtype)
  with jtu.strict_promotion_if_dtypes_match([dtype, update_dtype]):
    check_grads(jax_fn, (operand, updates), 2, rtol=1e-3, atol=1e-3, eps=1.)
|
2019-02-22 07:55:36 -05:00
|
|
|
|
2022-10-06 15:53:18 +00:00
|
|
|
@parameterized.parameters(itertools.chain.from_iterable(
  jtu.sample_product_testcases(
    [
      {"name": test_name, "unique_indices": unique, "shape": operand_shape,
       "indexer": index, "update_shape": upd_shape}
      # Indexers with repeated entries are only used when unique is False.
      for test_name, specs in (
        ADVANCED_INDEXING_TESTS_NO_REPEATS if unique
        else ADVANCED_INDEXING_TESTS)
      for operand_shape, index, result_shape in specs
      for upd_shape in _broadcastable_shapes(result_shape)
    ],
    [
      {"op": update_op, "dtype": operand_dtype, "update_dtype": upd_dtype}
      # MUL and UPDATE are only exercised with unique indices.
      for update_op in (
        [UpdateOps.ADD, UpdateOps.SUB, UpdateOps.MUL, UpdateOps.UPDATE]
        if unique
        else [UpdateOps.ADD, UpdateOps.SUB])
      for operand_dtype in float_dtypes
      for upd_dtype in _compatible_dtypes(update_op, operand_dtype,
                                          inexact=True)
    ],
  )
  for unique in [False, True]
))
def testAdvancedIndexingGrads(self, name, shape, dtype, update_shape,
                              update_dtype, indexer, op, unique_indices):
  # Checks gradients up to order 2 of advanced-indexing updates with respect
  # to both the operand and the update values.
  sampler = jtu.rand_default(self.rng())

  def jax_fn(x, y):
    return UpdateOps.jax_fn(op, indexer, x, y, unique_indices=unique_indices)

  operand = sampler(shape, dtype)
  updates = sampler(update_shape, update_dtype)
  with jtu.strict_promotion_if_dtypes_match([dtype, update_dtype]):
    check_grads(jax_fn, (operand, updates), 2, rtol=1e-3, atol=1e-3, eps=1.)
|
2021-05-12 14:16:35 -04:00
|
|
|
|
2022-01-24 14:54:32 -08:00
|
|
|
def testIndexMulGradFailsIfNotUnique(self):
|
|
|
|
y = jnp.ones((10,), jnp.int32)
|
|
|
|
f = lambda x, z: x.at[y].mul(z)
|
|
|
|
|
|
|
|
x = jnp.ones((100,), jnp.float32)
|
|
|
|
z = jnp.ones((10,), jnp.float32)
|
|
|
|
with self.assertRaises(NotImplementedError,
|
|
|
|
msg="scatter_mul gradients are only implemented if "
|
|
|
|
"`unique_indices=True`"):
|
|
|
|
jax.jvp(f, (x, z), (x, z))
|
|
|
|
|
2019-05-03 17:47:40 -07:00
|
|
|
def testSegmentSumBehavior(self):
  # testAdvancedIndexing compares against NumPy and therefore never exercises
  # repeated indices. This is a small manual check of scatter-add with
  # repeats, based on
  # https://www.tensorflow.org/api_docs/python/tf/math/segment_sum
  values = np.array([5, 1, 7, 2, 3, 4, 1, 3], dtype=float)
  ids = np.array([0, 0, 0, 1, 2, 2, 3, 3])

  num_segments = np.max(ids) + 1
  accumulator = jnp.zeros_like(values, shape=num_segments)
  ans = accumulator.at[ids].add(values)

  expected = np.array([13, 2, 7, 4])
  self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
def testSegmentSum(self):
  data = jnp.array([5, 1, 7, 2, 3, 4, 1, 3])

  def check(ids, expected, **segment_kwargs):
    # Runs segment_sum on `data` and compares values only (not dtypes).
    ans = ops.segment_sum(data, jnp.array(ids), **segment_kwargs)
    self.assertAllClose(ans, jnp.array(expected), check_dtypes=False)

  in_range_ids = [0, 0, 0, 1, 2, 2, 3, 3]

  # Explicit num_segments.
  check(in_range_ids, [13, 2, 7, 4], num_segments=4)

  # Explicit num_segments larger than the highest id: the extra segment
  # stays at zero.
  check(in_range_ids, [13, 2, 7, 4, 0], num_segments=5)

  # num_segments left unspecified.
  check(in_range_ids, [13, 2, 7, 4])

  # Negative ids and ids >= num_segments: per the expected values, those
  # entries contribute nothing to the result.
  check([0, 4, 8, 1, 2, -6, -1, 3], [5, 2, 3, 3], num_segments=4)

  # Negative ids with num_segments unspecified: the expected output has
  # (largest id + 1) segments and the negative ids contribute nothing.
  check([3, 3, 3, 4, 5, 5, -7, -6], [0, 0, 0, 13, 2, 7])
|
|
|
|
|
2021-11-22 13:32:25 -08:00
|
|
|
def testSegmentSumOutOfBounds(self):
  # Every segment id (2 and 3) is >= num_segments (2); the test expects both
  # the summed value and its gradient with respect to `data` to be zero.
  num_segments = 2
  data = np.array([0, 0], dtype=np.float32)
  segment_ids = np.array([2, 3])

  def total(values, ids):
    return jax.ops.segment_sum(values, ids, num_segments).sum()

  val, grad = jax.value_and_grad(total)(data, segment_ids)

  self.assertAllClose(val, np.array(0., np.float32))
  self.assertAllClose(grad, np.array([0., 0.], np.float32))
|
|
|
|
|
2021-04-09 12:06:51 -07:00
|
|
|
|
2022-10-06 15:53:18 +00:00
|
|
|
@parameterized.parameters(itertools.chain.from_iterable(
  jtu.sample_product_testcases(
    [dict(reducer=reducer, op=op, identity=identity)],
    dtype=[np.bool_],
    shape=[(8,), (7, 4), (6, 4, 2)],
    bucket_size=[None, 2],
    num_segments=[None, 1, 3],
  )
  for reducer, op, identity in [
    (ops.segment_min, np.minimum, True),
    (ops.segment_max, np.maximum, False),
  ]))
def testSegmentReduceBoolean(self, shape, dtype, reducer, op, identity,
                             num_segments, bucket_size):
  # Compares segment_min / segment_max on boolean data against a simple NumPy
  # loop. `identity` is the fill value for segments that receive no data.
  #
  # The integer-only remapping of infinite identities that the sibling
  # testSegmentReduce needs was removed here: `dtype` is always np.bool_,
  # which is not a sub-dtype of np.integer, so that branch could never run.
  rng = jtu.rand_default(self.rng())
  # low=-2 so that ids outside the valid segment range are presumably sampled.
  idx_rng = jtu.rand_int(self.rng(), low=-2, high=3)
  args_maker = lambda: [rng(shape, dtype), idx_rng(shape[:1], jnp.int32)]

  jnp_fun = lambda data, segment_ids: reducer(
    data, segment_ids, num_segments=num_segments, bucket_size=bucket_size)

  def np_fun(data, segment_ids):
    # Reference implementation: reduce row by row, ignoring ids that fall
    # outside [0, size).
    size = num_segments if num_segments is not None else (segment_ids.max() + 1)
    out = np.full((size,) + shape[1:], identity, dtype)
    for i, val in zip(segment_ids, data):
      if 0 <= i < size:
        out[i] = op(out[i], val).astype(dtype)
    return out

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  # NOTE(review): compilation is only checked with an explicit num_segments,
  # presumably because the output size is data-dependent otherwise.
  if num_segments is not None:
    self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
|
|
|
|
2022-10-06 15:53:18 +00:00
|
|
|
@parameterized.parameters(itertools.chain.from_iterable(
  jtu.sample_product_testcases(
    [{"reducer": segment_fn, "op": np_op, "identity": fill_value}],
    dtype=default_dtypes,
    shape=[(8,), (7, 4), (6, 4, 2)],
    bucket_size=[None, 2],
    num_segments=[None, 1, 3],
  )
  for segment_fn, np_op, fill_value in [
    (ops.segment_sum, np.add, 0),
    (ops.segment_prod, np.multiply, 1),
    (ops.segment_min, np.minimum, float('inf')),
    (ops.segment_max, np.maximum, -float('inf')),
  ]))
def testSegmentReduce(self, shape, dtype, reducer, op, identity, num_segments, bucket_size):
  # Compares each segment reduction against a simple NumPy loop. `identity`
  # is the fill value for segments that receive no data.
  sampler = jtu.rand_default(self.rng())
  id_sampler = jtu.rand_int(self.rng(), low=-2, high=3)

  def args_maker():
    return [sampler(shape, dtype), id_sampler(shape[:1], jnp.int32)]

  # For integer dtypes an infinite identity is replaced by the extreme value
  # reported by np.iinfo for that dtype.
  is_integer = np.issubdtype(dtype, np.integer)
  if is_integer and np.isposinf(identity):
    identity = np.iinfo(dtype).max
  elif is_integer and np.isneginf(identity):
    identity = np.iinfo(dtype).min

  def jnp_fun(data, segment_ids):
    return reducer(
        data, segment_ids, num_segments=num_segments, bucket_size=bucket_size)

  def np_fun(data, segment_ids):
    # Reference implementation: reduce row by row, skipping ids that fall
    # outside [0, size).
    if num_segments is None:
      size = segment_ids.max() + 1
    else:
      size = num_segments
    out = np.full((size,) + shape[1:], identity, dtype)
    for segment, row in zip(segment_ids, data):
      if segment < 0 or segment >= size:
        continue
      out[segment] = op(out[segment], row).astype(dtype)
    return out

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  if num_segments is None:
    return
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2020-04-22 22:08:29 -07:00
|
|
|
def testIndexDtypeError(self):
  # Regression test for https://github.com/jax-ml/jax/issues/2795:
  # a strided-slice update must not emit any warning.
  jnp.array(1)  # get rid of startup warning
  with self.assertNoWarnings():
    zeros = jnp.zeros(5)
    zeros.at[::2].set(1)
|
|
|
|
|
2022-10-06 15:53:18 +00:00
|
|
|
@jtu.sample_product(
  [
    {"idx": raw_index, "idx_type": kind}
    for raw_index, kind in [
      ([0], "array"),
      ([0, 0], "array"),
      ([[0, 0]], "tuple"),
      ([0, [0, 1]], "tuple"),
      ([0, np.arange(2)], "tuple"),
      ([0, None], "tuple"),
      ([0, slice(None)], "tuple"),
    ]
  ],
)
def testIndexSequenceDeprecation(self, idx, idx_type):
  # Indexing with a raw Python list must raise TypeError with the message for
  # its kind, while the same index converted to an array / tuple must work
  # without emitting warnings. Both `x[...]` and `x.at[...]` are covered.
  normalize, msg = {
      "array": (np.array, ARRAY_MSG),
      "tuple": (tuple, TUPLE_MSG),
  }[idx_type]
  x = jnp.arange(6).reshape(3, 2)

  # Plain __getitem__ indexing.
  with self.assertRaisesRegex(TypeError, msg):
    x[idx]
  with self.assertNoWarnings():
    x[normalize(idx)]

  # Indexed update through `.at`.
  with self.assertRaisesRegex(TypeError, msg):
    x.at[idx].set(0)
  with self.assertNoWarnings():
    x.at[normalize(idx)].set(0)
|
2020-10-19 14:07:40 -07:00
|
|
|
|
2021-10-13 11:11:24 -04:00
|
|
|
def testIndexedUpdateAliasingBug(self):
  # Regression test for https://github.com/jax-ml/jax/issues/7461:
  # an update whose values are computed from the array being updated must
  # give the same result eagerly and under jit.
  def shift_and_increment(x):
    return x.at[1:].set(1 + x[:-1])

  y = jnp.zeros(8)
  eager_result = shift_and_increment(y)
  jitted_result = jax.jit(shift_and_increment)(y)
  self.assertArraysEqual(eager_result, jitted_result)
|
2019-05-03 17:47:40 -07:00
|
|
|
|
2023-04-10 14:24:26 -07:00
|
|
|
def testScatterValuesCastToTargetDType(self):
|
2024-09-20 07:51:48 -07:00
|
|
|
# https://github.com/jax-ml/jax/issues/15505
|
2023-04-10 14:24:26 -07:00
|
|
|
a = jnp.zeros(1, dtype=jnp.uint32)
|
|
|
|
val = 2**32 - 1 # too large for int32
|
|
|
|
|
|
|
|
b = a.at[0].set(jnp.uint32(val))
|
|
|
|
self.assertEqual(int(b[0]), val)
|
|
|
|
|
|
|
|
c = a.at[0].set(val)
|
|
|
|
self.assertEqual(int(c[0]), val)
|
|
|
|
|
2025-01-14 11:20:50 -08:00
|
|
|
def testGradOfVmapOfScatter(self):
  # Regression test for https://github.com/jax-ml/jax/issues/25878
  def clipped_get(x, i):
    return x.at[i].get(mode='clip')

  x = jnp.array([1.0])
  i = jnp.array([1])  # out-of-bound index
  expected = jnp.array([[1.0]])

  # The Jacobian must be the same with and without vmap over the index.
  plain_jacobian = jax.jacrev(clipped_get)(x, i)
  self.assertArraysEqual(plain_jacobian, expected)

  batched_get = jax.vmap(clipped_get, (None, 0))
  batched_jacobian = jax.jacrev(batched_get)(x, i)
  self.assertArraysEqual(batched_jacobian, expected)
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
if __name__ == "__main__":
  # Run the tests in this module through absltest using JAX's test loader.
  jax_test_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=jax_test_loader)
|