Switch lax_numpy_indexing_test to use jtu.sample_product.

This commit is contained in:
Peter Hawkins 2022-10-06 15:53:18 +00:00
parent 219c574d8f
commit 8107e3600e

View File

@ -419,14 +419,13 @@ MODES = ["clip", "drop", "promise_in_bounds"]
class IndexingTest(jtu.JaxTestCase):
"""Tests for Numpy indexing translation rules."""
@parameterized.named_parameters(jtu.cases_from_list({
"testcase_name": "{}_inshape={}_indexer={}".format(
name, jtu.format_shape_dtype_string( shape, dtype), indexer),
"shape": shape, "dtype": dtype, "indexer": indexer
} for name, index_specs in STATIC_INDEXING_TESTS
for shape, indexer, _ in index_specs
for dtype in all_dtypes))
def testStaticIndexing(self, shape, dtype, indexer):
@jtu.sample_product(
[dict(name=name, shape=shape, indexer=indexer)
for name, index_specs in STATIC_INDEXING_TESTS
for shape, indexer, _ in index_specs],
dtype=all_dtypes
)
def testStaticIndexing(self, name, shape, dtype, indexer):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
np_fun = lambda x: np.asarray(x)[indexer]
@ -439,9 +438,9 @@ class IndexingTest(jtu.JaxTestCase):
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list({
"testcase_name": f"_{funcname}", "funcname": funcname}
for funcname in ["negative", "sin", "cos", "square", "sqrt", "log", "exp"]))
@jtu.sample_product(
funcname=["negative", "sin", "cos", "square", "sqrt", "log", "exp"],
)
def testIndexApply(self, funcname, size=10, dtype='float32'):
rng = jtu.rand_default(self.rng())
idx_rng = jtu.rand_int(self.rng(), -size, size)
@ -459,19 +458,17 @@ class IndexingTest(jtu.JaxTestCase):
self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters({
"testcase_name":
f"{jtu.format_shape_dtype_string(shape, dtype)}_inshape={name}"
f"_indexer={indexer}_mode={mode}",
"shape": shape, "dtype": dtype, "indexer": indexer, "mode": mode
}
for mode in MODES
for name, index_specs in (
STATIC_INDEXING_TESTS if mode == "promise_in_bounds" else
STATIC_INDEXING_TESTS + STATIC_INDEXING_OUT_OF_BOUNDS_TESTS)
for shape, indexer, _ in index_specs
for dtype in float_dtypes)
def testStaticIndexingGrads(self, shape, dtype, indexer, mode):
@jtu.sample_product(
[dict(name=name, shape=shape, indexer=indexer, mode=mode)
for mode in MODES
for name, index_specs in (
STATIC_INDEXING_TESTS if mode == "promise_in_bounds" else
STATIC_INDEXING_TESTS + STATIC_INDEXING_OUT_OF_BOUNDS_TESTS)
for shape, indexer, _ in index_specs
],
dtype=float_dtypes,
)
def testStaticIndexingGrads(self, name, shape, dtype, indexer, mode):
rng = jtu.rand_default(self.rng())
tol = 1e-2 if jnp.finfo(dtype).bits == 32 else None
arg = rng(shape, dtype)
@ -496,30 +493,30 @@ class IndexingTest(jtu.JaxTestCase):
else:
return idx, lambda x: x
@parameterized.named_parameters(
{"testcase_name": "{}_inshape={}_indexer={}"
.format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
"shape": shape, "dtype": dtype, "indexer": indexer}
@jtu.sample_product(
[dict(name=name, shape=shape, indexer=indexer)
for name, index_specs in [
("OneSliceIndex",
[IndexSpec(shape=(5,), indexer=slice(1, 3)),
IndexSpec(shape=(5, 4), indexer=slice(1, 3))]),
("TwoSliceIndices",
[IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))),
IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2)))]),
("NonUnitStrides", [
IndexSpec(shape=(3,), indexer=slice(None, None, -1)),
IndexSpec(shape=(3, 3), indexer=slice(0, 3, -2)),
IndexSpec(shape=(3, 4, 5), indexer=slice(0, 4, 2))
]),
("OnlyStartOrStopDynamic", [
IndexSpec(shape=(5, 4), indexer=(slice(None, 3), slice(0, 2))),
IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None)))
]),
("OneSliceIndex",
[IndexSpec(shape=(5,), indexer=slice(1, 3)),
IndexSpec(shape=(5, 4), indexer=slice(1, 3))]),
("TwoSliceIndices",
[IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))),
IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2)))]),
("NonUnitStrides", [
IndexSpec(shape=(3,), indexer=slice(None, None, -1)),
IndexSpec(shape=(3, 3), indexer=slice(0, 3, -2)),
IndexSpec(shape=(3, 4, 5), indexer=slice(0, 4, 2))
]),
("OnlyStartOrStopDynamic", [
IndexSpec(shape=(5, 4), indexer=(slice(None, 3), slice(0, 2))),
IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None)))
]),
]
for shape, indexer, _ in index_specs
for dtype in all_dtypes)
def testDynamicIndexingWithSlicesErrors(self, shape, dtype, indexer):
],
dtype=all_dtypes,
)
def testDynamicIndexingWithSlicesErrors(self, name, shape, dtype, indexer):
rng = jtu.rand_default(self.rng())
unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)
@ -531,10 +528,8 @@ class IndexingTest(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype), unpacked_indexer]
self.assertRaises(IndexError, lambda: fun(*args_maker()))
@parameterized.named_parameters(
{"testcase_name": "{}_inshape={}_indexer={}"
.format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
"shape": shape, "dtype": dtype, "indexer": indexer}
@jtu.sample_product(
[dict(name=name, shape=shape, indexer=indexer)
for name, index_specs in [
("OneIntIndex",
[IndexSpec(shape=(3,), indexer=1),
@ -550,8 +545,10 @@ class IndexingTest(jtu.JaxTestCase):
[IndexSpec((3, 4, 5), indexer=(1, 2, 3))]),
]
for shape, indexer, _ in index_specs
for dtype in all_dtypes)
def testDynamicIndexingWithIntegers(self, shape, dtype, indexer):
],
dtype=all_dtypes,
)
def testDynamicIndexingWithIntegers(self, name, shape, dtype, indexer):
rng = jtu.rand_default(self.rng())
unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)
@ -567,10 +564,8 @@ class IndexingTest(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(
{"testcase_name": "{}_inshape={}_indexer={}"
.format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
"shape": shape, "dtype": dtype, "indexer": indexer}
@jtu.sample_product(
[dict(name=name, shape=shape, indexer=indexer)
for name, index_specs in [
("OneIntIndex",
[IndexSpec(shape=(3,), indexer=1),
@ -588,8 +583,10 @@ class IndexingTest(jtu.JaxTestCase):
[IndexSpec((3, 4, 5), indexer=(1, 2, 3))]),
]
for shape, indexer, _ in index_specs
for dtype in float_dtypes)
def testDynamicIndexingWithIntegersGrads(self, shape, dtype, indexer):
],
dtype=float_dtypes,
)
def testDynamicIndexingWithIntegersGrads(self, name, shape, dtype, indexer):
rng = jtu.rand_default(self.rng())
tol = 1e-2 if jnp.finfo(dtype).bits == 32 else None
unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)
@ -602,14 +599,14 @@ class IndexingTest(jtu.JaxTestCase):
arr = rng(shape, dtype)
check_grads(partial(fun, unpacked_indexer), (arr,), 2, tol, tol, tol)
@parameterized.named_parameters(
{"testcase_name": "{}_inshape={}_indexer={}"
.format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
"shape": shape, "dtype": dtype, "indexer": indexer}
@jtu.sample_product(
[dict(name=name, shape=shape, indexer=indexer)
for name, index_specs in ADVANCED_INDEXING_TESTS
for shape, indexer, _ in index_specs
for dtype in all_dtypes)
def testAdvancedIntegerIndexing(self, shape, dtype, indexer):
],
dtype=all_dtypes,
)
def testAdvancedIntegerIndexing(self, name, shape, dtype, indexer):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype), indexer]
np_fun = lambda x, idx: np.asarray(x)[idx]
@ -617,9 +614,7 @@ class IndexingTest(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(
{"testcase_name": f"_{dtype}", "dtype": dtype}
for dtype in jtu.dtypes.unsigned + jtu.dtypes.integer)
@jtu.sample_product(dtype=jtu.dtypes.unsigned + jtu.dtypes.integer)
def testIndicesNormalizationByType(self, dtype):
x = jnp.arange(10)
jaxpr = jax.make_jaxpr(x.__getitem__)(jnp.arange(3, dtype=dtype))
@ -633,10 +628,8 @@ class IndexingTest(jtu.JaxTestCase):
self.assertEqual(primitives[:3], [lax.lt_p, lax.add_p, lax.select_n_p])
self.assertEqual(primitives[-2:], [lax.broadcast_in_dim_p, lax.gather_p])
@parameterized.named_parameters(
{"testcase_name": "{}_inshape={}_indexer={}"
.format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
"shape": shape, "dtype": dtype, "indexer": indexer}
@jtu.sample_product(
[dict(name=name, shape=shape, indexer=indexer)
for name, index_specs in [
("One1DIntArrayIndex",
[IndexSpec(shape=(3,), indexer=np.array([0, 1])),
@ -676,22 +669,24 @@ class IndexingTest(jtu.JaxTestCase):
]),
]
for shape, indexer, _ in index_specs
for dtype in float_dtypes)
def testAdvancedIntegerIndexingGrads(self, shape, dtype, indexer):
],
dtype=float_dtypes,
)
def testAdvancedIntegerIndexingGrads(self, name, shape, dtype, indexer):
rng = jtu.rand_default(self.rng())
tol = 1e-2 if jnp.finfo(dtype).bits == 32 else None
arg = rng(shape, dtype)
fun = lambda x: jnp.asarray(x)[indexer]
check_grads(fun, (arg,), 2, tol, tol, eps=1.)
@parameterized.named_parameters(
{"testcase_name": "{}_inshape={}_indexer={}"
.format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
"shape": shape, "dtype": dtype, "indexer": indexer}
@jtu.sample_product(
[dict(name=name, shape=shape, indexer=indexer)
for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS
for shape, indexer, _ in index_specs
for dtype in all_dtypes)
def testMixedAdvancedIntegerIndexing(self, shape, dtype, indexer):
],
dtype=all_dtypes,
)
def testMixedAdvancedIntegerIndexing(self, name, shape, dtype, indexer):
rng = jtu.rand_default(self.rng())
indexer_with_dummies = [e if isinstance(e, np.ndarray) else ()
for e in indexer]
@ -1113,7 +1108,7 @@ class UpdateOps(enum.Enum):
def _update_tol(op):
if op == UpdateOps.POW:
tol = {np.complex64: 1e-4 if jtu.device_under_test() == "tpu" else 1e-5,
tol = {np.complex64: 2e-4 if jtu.device_under_test() == "tpu" else 1e-5,
np.complex128: 1e-14}
else:
tol = {np.complex128: 1e-14}
@ -1122,23 +1117,20 @@ def _update_tol(op):
class IndexedUpdateTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
"testcase_name":
f"{name}_inshape={jtu.format_shape_dtype_string(shape, dtype)}"
f"_indexer={indexer}"
f"_update={jtu.format_shape_dtype_string(update_shape, update_dtype)}"
f"_op={op.name}",
"shape": shape, "dtype": dtype, "indexer": indexer,
"update_shape": update_shape, "update_dtype": update_dtype,
"op": op, "mode": mode,
} for name, index_specs in s(STATIC_INDEXING_TESTS)
for shape, indexer, update_shape in s(index_specs)
for op in s(UpdateOps)
for dtype in s(UpdateOps.dtypes(op))
for update_shape in s(_broadcastable_shapes(update_shape))
for update_dtype in s(_compatible_dtypes(op, dtype))
for mode in s(MODES))))
def testStaticIndexing(self, shape, dtype, update_shape, update_dtype,
@jtu.sample_product(
[dict(name=name, shape=shape, indexer=indexer, update_shape=update_shape)
for name, index_specs in STATIC_INDEXING_TESTS
for shape, indexer, index_shape in index_specs
for update_shape in _broadcastable_shapes(index_shape)
],
[dict(op=op, dtype=dtype, update_dtype=update_dtype)
for op in UpdateOps
for dtype in UpdateOps.dtypes(op)
for update_dtype in _compatible_dtypes(op, dtype)
],
mode=MODES,
)
def testStaticIndexing(self, name, shape, dtype, update_shape, update_dtype,
indexer, op, mode):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
@ -1148,20 +1140,19 @@ class IndexedUpdateTest(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol=_update_tol(op))
self._CompileAndCheck(jax_fn, args_maker)
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
"testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
name, jtu.format_shape_dtype_string(shape, dtype), indexer,
jtu.format_shape_dtype_string(update_shape, update_dtype), op.name),
"shape": shape, "dtype": dtype, "indexer": indexer,
"update_shape": update_shape, "update_dtype": update_dtype,
"op": op
} for name, index_specs in s(ADVANCED_INDEXING_TESTS_NO_REPEATS)
for shape, indexer, update_shape in s(index_specs)
for op in s(UpdateOps)
for dtype in s(UpdateOps.dtypes(op))
for update_shape in s(_broadcastable_shapes(update_shape))
for update_dtype in s(_compatible_dtypes(op, dtype)))))
def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
@jtu.sample_product(
[dict(name=name, shape=shape, indexer=indexer, update_shape=update_shape)
for name, index_specs in ADVANCED_INDEXING_TESTS_NO_REPEATS
for shape, indexer, index_shape in index_specs
for update_shape in _broadcastable_shapes(index_shape)
],
[dict(op=op, dtype=dtype, update_dtype=update_dtype)
for op in UpdateOps
for dtype in UpdateOps.dtypes(op)
for update_dtype in _compatible_dtypes(op, dtype)
],
)
def testAdvancedIndexing(self, name, shape, dtype, update_shape, update_dtype,
indexer, op):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
@ -1172,21 +1163,20 @@ class IndexedUpdateTest(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol=_update_tol(op))
self._CompileAndCheck(jax_fn, args_maker)
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
"testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
name, jtu.format_shape_dtype_string(shape, dtype), indexer,
jtu.format_shape_dtype_string(update_shape, update_dtype), op.name),
"shape": shape, "dtype": dtype, "indexer": indexer,
"update_shape": update_shape, "update_dtype": update_dtype,
"op": op
} for name, index_specs in s(ADVANCED_INDEXING_TESTS_NO_REPEATS_SORTED)
for shape, indexer, update_shape in s(index_specs)
for op in s(UpdateOps)
for dtype in s(UpdateOps.dtypes(op))
for update_shape in s(_broadcastable_shapes(update_shape))
for update_dtype in s(_compatible_dtypes(op, dtype)))))
def testAdvancedIndexingSorted(self, shape, dtype, update_shape, update_dtype,
indexer, op):
@jtu.sample_product(
[dict(name=name, shape=shape, indexer=indexer, update_shape=update_shape)
for name, index_specs in ADVANCED_INDEXING_TESTS_NO_REPEATS_SORTED
for shape, indexer, index_shape in index_specs
for update_shape in _broadcastable_shapes(index_shape)
],
[dict(op=op, dtype=dtype, update_dtype=update_dtype)
for op in UpdateOps
for dtype in UpdateOps.dtypes(op)
for update_dtype in _compatible_dtypes(op, dtype)
],
)
def testAdvancedIndexingSorted(self, name, shape, dtype, update_shape,
update_dtype, indexer, op):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
@ -1197,21 +1187,20 @@ class IndexedUpdateTest(jtu.JaxTestCase):
tol=_update_tol(op))
self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
"testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
name, jtu.format_shape_dtype_string(shape, dtype), indexer,
jtu.format_shape_dtype_string(update_shape, update_dtype), op.name),
"shape": shape, "dtype": dtype, "indexer": indexer,
"update_shape": update_shape, "update_dtype": update_dtype,
"op": op
} for name, index_specs in s(MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS)
for shape, indexer, update_shape in s(index_specs)
for op in s(UpdateOps)
for dtype in s(UpdateOps.dtypes(op))
for update_shape in s(_broadcastable_shapes(update_shape))
for update_dtype in s(_compatible_dtypes(op, dtype)))))
def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
indexer, op):
@jtu.sample_product(
[dict(name=name, shape=shape, indexer=indexer, update_shape=update_shape)
for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS
for shape, indexer, index_shape in index_specs
for update_shape in _broadcastable_shapes(index_shape)
],
[dict(op=op, dtype=dtype, update_dtype=update_dtype)
for op in UpdateOps
for dtype in UpdateOps.dtypes(op)
for update_dtype in _compatible_dtypes(op, dtype)
],
)
def testMixedAdvancedIndexing(self, name, shape, dtype, update_shape,
update_dtype, indexer, op):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
@ -1220,26 +1209,24 @@ class IndexedUpdateTest(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol=_update_tol(op))
self._CompileAndCheck(jax_fn, args_maker)
@parameterized.named_parameters(jtu.cases_from_list({
"testcase_name":
f"{name}_inshape={jtu.format_shape_dtype_string(shape, dtype)}"
f"_indexer={indexer}"
f"_update={jtu.format_shape_dtype_string(update_shape, update_dtype)}"
f"_op={op.name}_mode={mode}",
"shape": shape, "dtype": dtype, "indexer": indexer,
"update_shape": update_shape, "update_dtype": update_dtype,
"op": op, "mode": mode,
} for mode in [None] + MODES
for name, index_specs in (
STATIC_INDEXING_TESTS if mode == "promise_in_bounds" else
STATIC_INDEXING_TESTS + STATIC_INDEXING_OUT_OF_BOUNDS_TESTS)
for shape, indexer, update_shape in index_specs
for op in [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE]
for dtype in float_dtypes
for update_shape in _broadcastable_shapes(update_shape)
for update_dtype in _compatible_dtypes(op, dtype, inexact=True)))
def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
indexer, op, mode):
@jtu.sample_product(
[dict(name=name, mode=mode, shape=shape, indexer=indexer,
update_shape=update_shape)
for mode in [None] + MODES
for name, index_specs in (
STATIC_INDEXING_TESTS if mode == "promise_in_bounds" else
STATIC_INDEXING_TESTS + STATIC_INDEXING_OUT_OF_BOUNDS_TESTS)
for shape, indexer, index_shape in index_specs
for update_shape in _broadcastable_shapes(index_shape)
],
[dict(op=op, dtype=dtype, update_dtype=update_dtype)
for op in [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE]
for dtype in float_dtypes
for update_dtype in _compatible_dtypes(op, dtype, inexact=True)
],
)
def testStaticIndexingGrads(self, name, shape, dtype, update_shape,
update_dtype, indexer, op, mode):
rng = jtu.rand_default(self.rng())
jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y, mode=mode,
unique_indices=True)
@ -1248,26 +1235,28 @@ class IndexedUpdateTest(jtu.JaxTestCase):
with jtu.strict_promotion_if_dtypes_match([dtype, update_dtype]):
check_grads(jax_fn, (x, y), 2, rtol=1e-3, atol=1e-3, eps=1.)
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
"testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
name, jtu.format_shape_dtype_string(shape, dtype), indexer,
jtu.format_shape_dtype_string(update_shape, update_dtype), op.name),
"shape": shape, "dtype": dtype, "indexer": indexer,
"update_shape": update_shape, "update_dtype": update_dtype,
"op": op, "unique_indices": unique_indices,
} for unique_indices in s([False, True])
for name, index_specs in s(
ADVANCED_INDEXING_TESTS_NO_REPEATS if unique_indices
else ADVANCED_INDEXING_TESTS)
for shape, indexer, update_shape in s(index_specs)
for op in s(
[UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE] if unique_indices
else [UpdateOps.ADD])
for dtype in s(float_dtypes)
for update_shape in s(_broadcastable_shapes(update_shape))
for update_dtype in s(_compatible_dtypes(op, dtype, inexact=True)))))
def testAdvancedIndexingGrads(self, shape, dtype, update_shape, update_dtype,
indexer, op, unique_indices):
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(name=name, unique_indices=unique_indices, shape=shape,
indexer=indexer, update_shape=update_shape)
for name, index_specs in (
ADVANCED_INDEXING_TESTS_NO_REPEATS if unique_indices
else ADVANCED_INDEXING_TESTS)
for shape, indexer, index_shape in index_specs
for update_shape in _broadcastable_shapes(index_shape)
],
[dict(op=op, dtype=dtype, update_dtype=update_dtype)
for op in (
[UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE] if unique_indices
else [UpdateOps.ADD])
for dtype in float_dtypes
for update_dtype in _compatible_dtypes(op, dtype, inexact=True)
],
)
for unique_indices in [False, True]
))
def testAdvancedIndexingGrads(self, name, shape, dtype, update_shape,
update_dtype, indexer, op, unique_indices):
rng = jtu.rand_default(self.rng())
jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y,
unique_indices=unique_indices)
@ -1343,23 +1332,20 @@ class IndexedUpdateTest(jtu.JaxTestCase):
self.assertAllClose(grad, np.array([0., 0.], np.float32))
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list({
"testcase_name": "_{}_{}_num_segments={}_bucket_size={}".format(
jtu.format_shape_dtype_string(shape, dtype),
reducer.__name__, num_segments, bucket_size),
"dtype": dtype, "shape": shape,
"reducer": reducer, "op": op, "identity": identity,
"num_segments": num_segments, "bucket_size": bucket_size}
for dtype in [np.bool_]
for shape in [(8,), (7, 4), (6, 4, 2)]
for bucket_size in [None, 2]
for num_segments in [None, 1, 3])
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(reducer=reducer, op=op, identity=identity)],
dtype=[np.bool_],
shape=[(8,), (7, 4), (6, 4, 2)],
bucket_size=[None, 2],
num_segments=[None, 1, 3],
)
for reducer, op, identity in [
(ops.segment_min, np.minimum, True),
(ops.segment_max, np.maximum, False),
]))
def testSegmentReduceBoolean(self, shape, dtype, reducer, op, identity, num_segments, bucket_size):
def testSegmentReduceBoolean(self, shape, dtype, reducer, op, identity,
num_segments, bucket_size):
rng = jtu.rand_default(self.rng())
idx_rng = jtu.rand_int(self.rng(), low=-2, high=3)
args_maker = lambda: [rng(shape, dtype), idx_rng(shape[:1], jnp.int32)]
@ -1386,18 +1372,14 @@ class IndexedUpdateTest(jtu.JaxTestCase):
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list({
"testcase_name": "_{}_{}_num_segments={}_bucket_size={}".format(
jtu.format_shape_dtype_string(shape, dtype),
reducer.__name__, num_segments, bucket_size),
"dtype": dtype, "shape": shape,
"reducer": reducer, "op": op, "identity": identity,
"num_segments": num_segments, "bucket_size": bucket_size}
for dtype in default_dtypes
for shape in [(8,), (7, 4), (6, 4, 2)]
for bucket_size in [None, 2]
for num_segments in [None, 1, 3])
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(reducer=reducer, op=op, identity=identity)],
dtype=default_dtypes,
shape=[(8,), (7, 4), (6, 4, 2)],
bucket_size=[None, 2],
num_segments=[None, 1, 3],
)
for reducer, op, identity in [
(ops.segment_sum, np.add, 0),
(ops.segment_prod, np.multiply, 1),
@ -1444,17 +1426,19 @@ class IndexedUpdateTest(jtu.JaxTestCase):
yield
self.assertEmpty(caught_warnings)
@parameterized.named_parameters(jtu.cases_from_list({
"testcase_name": f"idx={idx}", "idx": idx, "idx_type": idx_type}
for idx, idx_type in [
([0], "array"),
([0, 0], "array"),
([[0, 0]], "tuple"),
([0, [0, 1]], "tuple"),
([0, np.arange(2)], "tuple"),
([0, None], "tuple"),
([0, slice(None)], "tuple"),
]))
@jtu.sample_product(
[dict(idx=idx, idx_type=idx_type)
for idx, idx_type in [
([0], "array"),
([0, 0], "array"),
([[0, 0]], "tuple"),
([0, [0, 1]], "tuple"),
([0, np.arange(2)], "tuple"),
([0, None], "tuple"),
([0, slice(None)], "tuple"),
]
],
)
def testIndexSequenceDeprecation(self, idx, idx_type):
normalize = {"array": np.array, "tuple": tuple}[idx_type]
msg = {"array": ARRAY_MSG, "tuple": TUPLE_MSG}[idx_type]