mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Switch lax_numpy_indexing_test to use jtu.sample_product.
This commit is contained in:
parent
219c574d8f
commit
8107e3600e
@ -419,14 +419,13 @@ MODES = ["clip", "drop", "promise_in_bounds"]
|
||||
class IndexingTest(jtu.JaxTestCase):
|
||||
"""Tests for Numpy indexing translation rules."""
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list({
|
||||
"testcase_name": "{}_inshape={}_indexer={}".format(
|
||||
name, jtu.format_shape_dtype_string( shape, dtype), indexer),
|
||||
"shape": shape, "dtype": dtype, "indexer": indexer
|
||||
} for name, index_specs in STATIC_INDEXING_TESTS
|
||||
for shape, indexer, _ in index_specs
|
||||
for dtype in all_dtypes))
|
||||
def testStaticIndexing(self, shape, dtype, indexer):
|
||||
@jtu.sample_product(
|
||||
[dict(name=name, shape=shape, indexer=indexer)
|
||||
for name, index_specs in STATIC_INDEXING_TESTS
|
||||
for shape, indexer, _ in index_specs],
|
||||
dtype=all_dtypes
|
||||
)
|
||||
def testStaticIndexing(self, name, shape, dtype, indexer):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
np_fun = lambda x: np.asarray(x)[indexer]
|
||||
@ -439,9 +438,9 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list({
|
||||
"testcase_name": f"_{funcname}", "funcname": funcname}
|
||||
for funcname in ["negative", "sin", "cos", "square", "sqrt", "log", "exp"]))
|
||||
@jtu.sample_product(
|
||||
funcname=["negative", "sin", "cos", "square", "sqrt", "log", "exp"],
|
||||
)
|
||||
def testIndexApply(self, funcname, size=10, dtype='float32'):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
idx_rng = jtu.rand_int(self.rng(), -size, size)
|
||||
@ -459,19 +458,17 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
|
||||
@parameterized.named_parameters({
|
||||
"testcase_name":
|
||||
f"{jtu.format_shape_dtype_string(shape, dtype)}_inshape={name}"
|
||||
f"_indexer={indexer}_mode={mode}",
|
||||
"shape": shape, "dtype": dtype, "indexer": indexer, "mode": mode
|
||||
}
|
||||
for mode in MODES
|
||||
for name, index_specs in (
|
||||
STATIC_INDEXING_TESTS if mode == "promise_in_bounds" else
|
||||
STATIC_INDEXING_TESTS + STATIC_INDEXING_OUT_OF_BOUNDS_TESTS)
|
||||
for shape, indexer, _ in index_specs
|
||||
for dtype in float_dtypes)
|
||||
def testStaticIndexingGrads(self, shape, dtype, indexer, mode):
|
||||
@jtu.sample_product(
|
||||
[dict(name=name, shape=shape, indexer=indexer, mode=mode)
|
||||
for mode in MODES
|
||||
for name, index_specs in (
|
||||
STATIC_INDEXING_TESTS if mode == "promise_in_bounds" else
|
||||
STATIC_INDEXING_TESTS + STATIC_INDEXING_OUT_OF_BOUNDS_TESTS)
|
||||
for shape, indexer, _ in index_specs
|
||||
],
|
||||
dtype=float_dtypes,
|
||||
)
|
||||
def testStaticIndexingGrads(self, name, shape, dtype, indexer, mode):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
tol = 1e-2 if jnp.finfo(dtype).bits == 32 else None
|
||||
arg = rng(shape, dtype)
|
||||
@ -496,30 +493,30 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
else:
|
||||
return idx, lambda x: x
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": "{}_inshape={}_indexer={}"
|
||||
.format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
|
||||
"shape": shape, "dtype": dtype, "indexer": indexer}
|
||||
@jtu.sample_product(
|
||||
[dict(name=name, shape=shape, indexer=indexer)
|
||||
for name, index_specs in [
|
||||
("OneSliceIndex",
|
||||
[IndexSpec(shape=(5,), indexer=slice(1, 3)),
|
||||
IndexSpec(shape=(5, 4), indexer=slice(1, 3))]),
|
||||
("TwoSliceIndices",
|
||||
[IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))),
|
||||
IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2)))]),
|
||||
("NonUnitStrides", [
|
||||
IndexSpec(shape=(3,), indexer=slice(None, None, -1)),
|
||||
IndexSpec(shape=(3, 3), indexer=slice(0, 3, -2)),
|
||||
IndexSpec(shape=(3, 4, 5), indexer=slice(0, 4, 2))
|
||||
]),
|
||||
("OnlyStartOrStopDynamic", [
|
||||
IndexSpec(shape=(5, 4), indexer=(slice(None, 3), slice(0, 2))),
|
||||
IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None)))
|
||||
]),
|
||||
("OneSliceIndex",
|
||||
[IndexSpec(shape=(5,), indexer=slice(1, 3)),
|
||||
IndexSpec(shape=(5, 4), indexer=slice(1, 3))]),
|
||||
("TwoSliceIndices",
|
||||
[IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))),
|
||||
IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2)))]),
|
||||
("NonUnitStrides", [
|
||||
IndexSpec(shape=(3,), indexer=slice(None, None, -1)),
|
||||
IndexSpec(shape=(3, 3), indexer=slice(0, 3, -2)),
|
||||
IndexSpec(shape=(3, 4, 5), indexer=slice(0, 4, 2))
|
||||
]),
|
||||
("OnlyStartOrStopDynamic", [
|
||||
IndexSpec(shape=(5, 4), indexer=(slice(None, 3), slice(0, 2))),
|
||||
IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None)))
|
||||
]),
|
||||
]
|
||||
for shape, indexer, _ in index_specs
|
||||
for dtype in all_dtypes)
|
||||
def testDynamicIndexingWithSlicesErrors(self, shape, dtype, indexer):
|
||||
],
|
||||
dtype=all_dtypes,
|
||||
)
|
||||
def testDynamicIndexingWithSlicesErrors(self, name, shape, dtype, indexer):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)
|
||||
|
||||
@ -531,10 +528,8 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(shape, dtype), unpacked_indexer]
|
||||
self.assertRaises(IndexError, lambda: fun(*args_maker()))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": "{}_inshape={}_indexer={}"
|
||||
.format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
|
||||
"shape": shape, "dtype": dtype, "indexer": indexer}
|
||||
@jtu.sample_product(
|
||||
[dict(name=name, shape=shape, indexer=indexer)
|
||||
for name, index_specs in [
|
||||
("OneIntIndex",
|
||||
[IndexSpec(shape=(3,), indexer=1),
|
||||
@ -550,8 +545,10 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
[IndexSpec((3, 4, 5), indexer=(1, 2, 3))]),
|
||||
]
|
||||
for shape, indexer, _ in index_specs
|
||||
for dtype in all_dtypes)
|
||||
def testDynamicIndexingWithIntegers(self, shape, dtype, indexer):
|
||||
],
|
||||
dtype=all_dtypes,
|
||||
)
|
||||
def testDynamicIndexingWithIntegers(self, name, shape, dtype, indexer):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)
|
||||
|
||||
@ -567,10 +564,8 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": "{}_inshape={}_indexer={}"
|
||||
.format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
|
||||
"shape": shape, "dtype": dtype, "indexer": indexer}
|
||||
@jtu.sample_product(
|
||||
[dict(name=name, shape=shape, indexer=indexer)
|
||||
for name, index_specs in [
|
||||
("OneIntIndex",
|
||||
[IndexSpec(shape=(3,), indexer=1),
|
||||
@ -588,8 +583,10 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
[IndexSpec((3, 4, 5), indexer=(1, 2, 3))]),
|
||||
]
|
||||
for shape, indexer, _ in index_specs
|
||||
for dtype in float_dtypes)
|
||||
def testDynamicIndexingWithIntegersGrads(self, shape, dtype, indexer):
|
||||
],
|
||||
dtype=float_dtypes,
|
||||
)
|
||||
def testDynamicIndexingWithIntegersGrads(self, name, shape, dtype, indexer):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
tol = 1e-2 if jnp.finfo(dtype).bits == 32 else None
|
||||
unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)
|
||||
@ -602,14 +599,14 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
arr = rng(shape, dtype)
|
||||
check_grads(partial(fun, unpacked_indexer), (arr,), 2, tol, tol, tol)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": "{}_inshape={}_indexer={}"
|
||||
.format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
|
||||
"shape": shape, "dtype": dtype, "indexer": indexer}
|
||||
@jtu.sample_product(
|
||||
[dict(name=name, shape=shape, indexer=indexer)
|
||||
for name, index_specs in ADVANCED_INDEXING_TESTS
|
||||
for shape, indexer, _ in index_specs
|
||||
for dtype in all_dtypes)
|
||||
def testAdvancedIntegerIndexing(self, shape, dtype, indexer):
|
||||
],
|
||||
dtype=all_dtypes,
|
||||
)
|
||||
def testAdvancedIntegerIndexing(self, name, shape, dtype, indexer):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype), indexer]
|
||||
np_fun = lambda x, idx: np.asarray(x)[idx]
|
||||
@ -617,9 +614,7 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{dtype}", "dtype": dtype}
|
||||
for dtype in jtu.dtypes.unsigned + jtu.dtypes.integer)
|
||||
@jtu.sample_product(dtype=jtu.dtypes.unsigned + jtu.dtypes.integer)
|
||||
def testIndicesNormalizationByType(self, dtype):
|
||||
x = jnp.arange(10)
|
||||
jaxpr = jax.make_jaxpr(x.__getitem__)(jnp.arange(3, dtype=dtype))
|
||||
@ -633,10 +628,8 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
self.assertEqual(primitives[:3], [lax.lt_p, lax.add_p, lax.select_n_p])
|
||||
self.assertEqual(primitives[-2:], [lax.broadcast_in_dim_p, lax.gather_p])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": "{}_inshape={}_indexer={}"
|
||||
.format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
|
||||
"shape": shape, "dtype": dtype, "indexer": indexer}
|
||||
@jtu.sample_product(
|
||||
[dict(name=name, shape=shape, indexer=indexer)
|
||||
for name, index_specs in [
|
||||
("One1DIntArrayIndex",
|
||||
[IndexSpec(shape=(3,), indexer=np.array([0, 1])),
|
||||
@ -676,22 +669,24 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
]),
|
||||
]
|
||||
for shape, indexer, _ in index_specs
|
||||
for dtype in float_dtypes)
|
||||
def testAdvancedIntegerIndexingGrads(self, shape, dtype, indexer):
|
||||
],
|
||||
dtype=float_dtypes,
|
||||
)
|
||||
def testAdvancedIntegerIndexingGrads(self, name, shape, dtype, indexer):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
tol = 1e-2 if jnp.finfo(dtype).bits == 32 else None
|
||||
arg = rng(shape, dtype)
|
||||
fun = lambda x: jnp.asarray(x)[indexer]
|
||||
check_grads(fun, (arg,), 2, tol, tol, eps=1.)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": "{}_inshape={}_indexer={}"
|
||||
.format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
|
||||
"shape": shape, "dtype": dtype, "indexer": indexer}
|
||||
@jtu.sample_product(
|
||||
[dict(name=name, shape=shape, indexer=indexer)
|
||||
for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS
|
||||
for shape, indexer, _ in index_specs
|
||||
for dtype in all_dtypes)
|
||||
def testMixedAdvancedIntegerIndexing(self, shape, dtype, indexer):
|
||||
],
|
||||
dtype=all_dtypes,
|
||||
)
|
||||
def testMixedAdvancedIntegerIndexing(self, name, shape, dtype, indexer):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
indexer_with_dummies = [e if isinstance(e, np.ndarray) else ()
|
||||
for e in indexer]
|
||||
@ -1113,7 +1108,7 @@ class UpdateOps(enum.Enum):
|
||||
|
||||
def _update_tol(op):
|
||||
if op == UpdateOps.POW:
|
||||
tol = {np.complex64: 1e-4 if jtu.device_under_test() == "tpu" else 1e-5,
|
||||
tol = {np.complex64: 2e-4 if jtu.device_under_test() == "tpu" else 1e-5,
|
||||
np.complex128: 1e-14}
|
||||
else:
|
||||
tol = {np.complex128: 1e-14}
|
||||
@ -1122,23 +1117,20 @@ def _update_tol(op):
|
||||
|
||||
class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
|
||||
"testcase_name":
|
||||
f"{name}_inshape={jtu.format_shape_dtype_string(shape, dtype)}"
|
||||
f"_indexer={indexer}"
|
||||
f"_update={jtu.format_shape_dtype_string(update_shape, update_dtype)}"
|
||||
f"_op={op.name}",
|
||||
"shape": shape, "dtype": dtype, "indexer": indexer,
|
||||
"update_shape": update_shape, "update_dtype": update_dtype,
|
||||
"op": op, "mode": mode,
|
||||
} for name, index_specs in s(STATIC_INDEXING_TESTS)
|
||||
for shape, indexer, update_shape in s(index_specs)
|
||||
for op in s(UpdateOps)
|
||||
for dtype in s(UpdateOps.dtypes(op))
|
||||
for update_shape in s(_broadcastable_shapes(update_shape))
|
||||
for update_dtype in s(_compatible_dtypes(op, dtype))
|
||||
for mode in s(MODES))))
|
||||
def testStaticIndexing(self, shape, dtype, update_shape, update_dtype,
|
||||
@jtu.sample_product(
|
||||
[dict(name=name, shape=shape, indexer=indexer, update_shape=update_shape)
|
||||
for name, index_specs in STATIC_INDEXING_TESTS
|
||||
for shape, indexer, index_shape in index_specs
|
||||
for update_shape in _broadcastable_shapes(index_shape)
|
||||
],
|
||||
[dict(op=op, dtype=dtype, update_dtype=update_dtype)
|
||||
for op in UpdateOps
|
||||
for dtype in UpdateOps.dtypes(op)
|
||||
for update_dtype in _compatible_dtypes(op, dtype)
|
||||
],
|
||||
mode=MODES,
|
||||
)
|
||||
def testStaticIndexing(self, name, shape, dtype, update_shape, update_dtype,
|
||||
indexer, op, mode):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
|
||||
@ -1148,20 +1140,19 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol=_update_tol(op))
|
||||
self._CompileAndCheck(jax_fn, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
|
||||
"testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
|
||||
name, jtu.format_shape_dtype_string(shape, dtype), indexer,
|
||||
jtu.format_shape_dtype_string(update_shape, update_dtype), op.name),
|
||||
"shape": shape, "dtype": dtype, "indexer": indexer,
|
||||
"update_shape": update_shape, "update_dtype": update_dtype,
|
||||
"op": op
|
||||
} for name, index_specs in s(ADVANCED_INDEXING_TESTS_NO_REPEATS)
|
||||
for shape, indexer, update_shape in s(index_specs)
|
||||
for op in s(UpdateOps)
|
||||
for dtype in s(UpdateOps.dtypes(op))
|
||||
for update_shape in s(_broadcastable_shapes(update_shape))
|
||||
for update_dtype in s(_compatible_dtypes(op, dtype)))))
|
||||
def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
|
||||
@jtu.sample_product(
|
||||
[dict(name=name, shape=shape, indexer=indexer, update_shape=update_shape)
|
||||
for name, index_specs in ADVANCED_INDEXING_TESTS_NO_REPEATS
|
||||
for shape, indexer, index_shape in index_specs
|
||||
for update_shape in _broadcastable_shapes(index_shape)
|
||||
],
|
||||
[dict(op=op, dtype=dtype, update_dtype=update_dtype)
|
||||
for op in UpdateOps
|
||||
for dtype in UpdateOps.dtypes(op)
|
||||
for update_dtype in _compatible_dtypes(op, dtype)
|
||||
],
|
||||
)
|
||||
def testAdvancedIndexing(self, name, shape, dtype, update_shape, update_dtype,
|
||||
indexer, op):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
|
||||
@ -1172,21 +1163,20 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol=_update_tol(op))
|
||||
self._CompileAndCheck(jax_fn, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
|
||||
"testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
|
||||
name, jtu.format_shape_dtype_string(shape, dtype), indexer,
|
||||
jtu.format_shape_dtype_string(update_shape, update_dtype), op.name),
|
||||
"shape": shape, "dtype": dtype, "indexer": indexer,
|
||||
"update_shape": update_shape, "update_dtype": update_dtype,
|
||||
"op": op
|
||||
} for name, index_specs in s(ADVANCED_INDEXING_TESTS_NO_REPEATS_SORTED)
|
||||
for shape, indexer, update_shape in s(index_specs)
|
||||
for op in s(UpdateOps)
|
||||
for dtype in s(UpdateOps.dtypes(op))
|
||||
for update_shape in s(_broadcastable_shapes(update_shape))
|
||||
for update_dtype in s(_compatible_dtypes(op, dtype)))))
|
||||
def testAdvancedIndexingSorted(self, shape, dtype, update_shape, update_dtype,
|
||||
indexer, op):
|
||||
@jtu.sample_product(
|
||||
[dict(name=name, shape=shape, indexer=indexer, update_shape=update_shape)
|
||||
for name, index_specs in ADVANCED_INDEXING_TESTS_NO_REPEATS_SORTED
|
||||
for shape, indexer, index_shape in index_specs
|
||||
for update_shape in _broadcastable_shapes(index_shape)
|
||||
],
|
||||
[dict(op=op, dtype=dtype, update_dtype=update_dtype)
|
||||
for op in UpdateOps
|
||||
for dtype in UpdateOps.dtypes(op)
|
||||
for update_dtype in _compatible_dtypes(op, dtype)
|
||||
],
|
||||
)
|
||||
def testAdvancedIndexingSorted(self, name, shape, dtype, update_shape,
|
||||
update_dtype, indexer, op):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
|
||||
np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
|
||||
@ -1197,21 +1187,20 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
tol=_update_tol(op))
|
||||
self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)
|
||||
|
||||
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
|
||||
"testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
|
||||
name, jtu.format_shape_dtype_string(shape, dtype), indexer,
|
||||
jtu.format_shape_dtype_string(update_shape, update_dtype), op.name),
|
||||
"shape": shape, "dtype": dtype, "indexer": indexer,
|
||||
"update_shape": update_shape, "update_dtype": update_dtype,
|
||||
"op": op
|
||||
} for name, index_specs in s(MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS)
|
||||
for shape, indexer, update_shape in s(index_specs)
|
||||
for op in s(UpdateOps)
|
||||
for dtype in s(UpdateOps.dtypes(op))
|
||||
for update_shape in s(_broadcastable_shapes(update_shape))
|
||||
for update_dtype in s(_compatible_dtypes(op, dtype)))))
|
||||
def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
|
||||
indexer, op):
|
||||
@jtu.sample_product(
|
||||
[dict(name=name, shape=shape, indexer=indexer, update_shape=update_shape)
|
||||
for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS
|
||||
for shape, indexer, index_shape in index_specs
|
||||
for update_shape in _broadcastable_shapes(index_shape)
|
||||
],
|
||||
[dict(op=op, dtype=dtype, update_dtype=update_dtype)
|
||||
for op in UpdateOps
|
||||
for dtype in UpdateOps.dtypes(op)
|
||||
for update_dtype in _compatible_dtypes(op, dtype)
|
||||
],
|
||||
)
|
||||
def testMixedAdvancedIndexing(self, name, shape, dtype, update_shape,
|
||||
update_dtype, indexer, op):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
|
||||
np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
|
||||
@ -1220,26 +1209,24 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol=_update_tol(op))
|
||||
self._CompileAndCheck(jax_fn, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list({
|
||||
"testcase_name":
|
||||
f"{name}_inshape={jtu.format_shape_dtype_string(shape, dtype)}"
|
||||
f"_indexer={indexer}"
|
||||
f"_update={jtu.format_shape_dtype_string(update_shape, update_dtype)}"
|
||||
f"_op={op.name}_mode={mode}",
|
||||
"shape": shape, "dtype": dtype, "indexer": indexer,
|
||||
"update_shape": update_shape, "update_dtype": update_dtype,
|
||||
"op": op, "mode": mode,
|
||||
} for mode in [None] + MODES
|
||||
for name, index_specs in (
|
||||
STATIC_INDEXING_TESTS if mode == "promise_in_bounds" else
|
||||
STATIC_INDEXING_TESTS + STATIC_INDEXING_OUT_OF_BOUNDS_TESTS)
|
||||
for shape, indexer, update_shape in index_specs
|
||||
for op in [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE]
|
||||
for dtype in float_dtypes
|
||||
for update_shape in _broadcastable_shapes(update_shape)
|
||||
for update_dtype in _compatible_dtypes(op, dtype, inexact=True)))
|
||||
def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
|
||||
indexer, op, mode):
|
||||
@jtu.sample_product(
|
||||
[dict(name=name, mode=mode, shape=shape, indexer=indexer,
|
||||
update_shape=update_shape)
|
||||
for mode in [None] + MODES
|
||||
for name, index_specs in (
|
||||
STATIC_INDEXING_TESTS if mode == "promise_in_bounds" else
|
||||
STATIC_INDEXING_TESTS + STATIC_INDEXING_OUT_OF_BOUNDS_TESTS)
|
||||
for shape, indexer, index_shape in index_specs
|
||||
for update_shape in _broadcastable_shapes(index_shape)
|
||||
],
|
||||
[dict(op=op, dtype=dtype, update_dtype=update_dtype)
|
||||
for op in [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE]
|
||||
for dtype in float_dtypes
|
||||
for update_dtype in _compatible_dtypes(op, dtype, inexact=True)
|
||||
],
|
||||
)
|
||||
def testStaticIndexingGrads(self, name, shape, dtype, update_shape,
|
||||
update_dtype, indexer, op, mode):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y, mode=mode,
|
||||
unique_indices=True)
|
||||
@ -1248,26 +1235,28 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
with jtu.strict_promotion_if_dtypes_match([dtype, update_dtype]):
|
||||
check_grads(jax_fn, (x, y), 2, rtol=1e-3, atol=1e-3, eps=1.)
|
||||
|
||||
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
|
||||
"testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
|
||||
name, jtu.format_shape_dtype_string(shape, dtype), indexer,
|
||||
jtu.format_shape_dtype_string(update_shape, update_dtype), op.name),
|
||||
"shape": shape, "dtype": dtype, "indexer": indexer,
|
||||
"update_shape": update_shape, "update_dtype": update_dtype,
|
||||
"op": op, "unique_indices": unique_indices,
|
||||
} for unique_indices in s([False, True])
|
||||
for name, index_specs in s(
|
||||
ADVANCED_INDEXING_TESTS_NO_REPEATS if unique_indices
|
||||
else ADVANCED_INDEXING_TESTS)
|
||||
for shape, indexer, update_shape in s(index_specs)
|
||||
for op in s(
|
||||
[UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE] if unique_indices
|
||||
else [UpdateOps.ADD])
|
||||
for dtype in s(float_dtypes)
|
||||
for update_shape in s(_broadcastable_shapes(update_shape))
|
||||
for update_dtype in s(_compatible_dtypes(op, dtype, inexact=True)))))
|
||||
def testAdvancedIndexingGrads(self, shape, dtype, update_shape, update_dtype,
|
||||
indexer, op, unique_indices):
|
||||
@parameterized.parameters(itertools.chain.from_iterable(
|
||||
jtu.sample_product_testcases(
|
||||
[dict(name=name, unique_indices=unique_indices, shape=shape,
|
||||
indexer=indexer, update_shape=update_shape)
|
||||
for name, index_specs in (
|
||||
ADVANCED_INDEXING_TESTS_NO_REPEATS if unique_indices
|
||||
else ADVANCED_INDEXING_TESTS)
|
||||
for shape, indexer, index_shape in index_specs
|
||||
for update_shape in _broadcastable_shapes(index_shape)
|
||||
],
|
||||
[dict(op=op, dtype=dtype, update_dtype=update_dtype)
|
||||
for op in (
|
||||
[UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE] if unique_indices
|
||||
else [UpdateOps.ADD])
|
||||
for dtype in float_dtypes
|
||||
for update_dtype in _compatible_dtypes(op, dtype, inexact=True)
|
||||
],
|
||||
)
|
||||
for unique_indices in [False, True]
|
||||
))
|
||||
def testAdvancedIndexingGrads(self, name, shape, dtype, update_shape,
|
||||
update_dtype, indexer, op, unique_indices):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y,
|
||||
unique_indices=unique_indices)
|
||||
@ -1343,23 +1332,20 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(grad, np.array([0., 0.], np.float32))
|
||||
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list({
|
||||
"testcase_name": "_{}_{}_num_segments={}_bucket_size={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype),
|
||||
reducer.__name__, num_segments, bucket_size),
|
||||
"dtype": dtype, "shape": shape,
|
||||
"reducer": reducer, "op": op, "identity": identity,
|
||||
"num_segments": num_segments, "bucket_size": bucket_size}
|
||||
for dtype in [np.bool_]
|
||||
for shape in [(8,), (7, 4), (6, 4, 2)]
|
||||
for bucket_size in [None, 2]
|
||||
for num_segments in [None, 1, 3])
|
||||
@parameterized.parameters(itertools.chain.from_iterable(
|
||||
jtu.sample_product_testcases(
|
||||
[dict(reducer=reducer, op=op, identity=identity)],
|
||||
dtype=[np.bool_],
|
||||
shape=[(8,), (7, 4), (6, 4, 2)],
|
||||
bucket_size=[None, 2],
|
||||
num_segments=[None, 1, 3],
|
||||
)
|
||||
for reducer, op, identity in [
|
||||
(ops.segment_min, np.minimum, True),
|
||||
(ops.segment_max, np.maximum, False),
|
||||
]))
|
||||
def testSegmentReduceBoolean(self, shape, dtype, reducer, op, identity, num_segments, bucket_size):
|
||||
def testSegmentReduceBoolean(self, shape, dtype, reducer, op, identity,
|
||||
num_segments, bucket_size):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
idx_rng = jtu.rand_int(self.rng(), low=-2, high=3)
|
||||
args_maker = lambda: [rng(shape, dtype), idx_rng(shape[:1], jnp.int32)]
|
||||
@ -1386,18 +1372,14 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list({
|
||||
"testcase_name": "_{}_{}_num_segments={}_bucket_size={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype),
|
||||
reducer.__name__, num_segments, bucket_size),
|
||||
"dtype": dtype, "shape": shape,
|
||||
"reducer": reducer, "op": op, "identity": identity,
|
||||
"num_segments": num_segments, "bucket_size": bucket_size}
|
||||
for dtype in default_dtypes
|
||||
for shape in [(8,), (7, 4), (6, 4, 2)]
|
||||
for bucket_size in [None, 2]
|
||||
for num_segments in [None, 1, 3])
|
||||
@parameterized.parameters(itertools.chain.from_iterable(
|
||||
jtu.sample_product_testcases(
|
||||
[dict(reducer=reducer, op=op, identity=identity)],
|
||||
dtype=default_dtypes,
|
||||
shape=[(8,), (7, 4), (6, 4, 2)],
|
||||
bucket_size=[None, 2],
|
||||
num_segments=[None, 1, 3],
|
||||
)
|
||||
for reducer, op, identity in [
|
||||
(ops.segment_sum, np.add, 0),
|
||||
(ops.segment_prod, np.multiply, 1),
|
||||
@ -1444,17 +1426,19 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
yield
|
||||
self.assertEmpty(caught_warnings)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list({
|
||||
"testcase_name": f"idx={idx}", "idx": idx, "idx_type": idx_type}
|
||||
for idx, idx_type in [
|
||||
([0], "array"),
|
||||
([0, 0], "array"),
|
||||
([[0, 0]], "tuple"),
|
||||
([0, [0, 1]], "tuple"),
|
||||
([0, np.arange(2)], "tuple"),
|
||||
([0, None], "tuple"),
|
||||
([0, slice(None)], "tuple"),
|
||||
]))
|
||||
@jtu.sample_product(
|
||||
[dict(idx=idx, idx_type=idx_type)
|
||||
for idx, idx_type in [
|
||||
([0], "array"),
|
||||
([0, 0], "array"),
|
||||
([[0, 0]], "tuple"),
|
||||
([0, [0, 1]], "tuple"),
|
||||
([0, np.arange(2)], "tuple"),
|
||||
([0, None], "tuple"),
|
||||
([0, slice(None)], "tuple"),
|
||||
]
|
||||
],
|
||||
)
|
||||
def testIndexSequenceDeprecation(self, idx, idx_type):
|
||||
normalize = {"array": np.array, "tuple": tuple}[idx_type]
|
||||
msg = {"array": ARRAY_MSG, "tuple": TUPLE_MSG}[idx_type]
|
||||
|
Loading…
x
Reference in New Issue
Block a user