diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index def9d21b0..2a1c52a5c 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -658,20 +658,6 @@ def assert_dot_precision(expected_precision, fun, *args): assert precision == expected_precision, msg -_CACHED_INDICES: Dict[int, Sequence[int]] = {} - -def cases_from_list(xs): - xs = list(xs) - n = len(xs) - k = min(n, FLAGS.num_generated_cases) - # Random sampling for every parameterized test is expensive. Do it once and - # cache the result. - indices = _CACHED_INDICES.get(n) - if indices is None: - rng = npr.RandomState(42) - _CACHED_INDICES[n] = indices = rng.permutation(n) - return [xs[i] for i in indices[:k]] - def cases_from_gens(*gens): sizes = [1, 3, 10] cases_per_size = int(FLAGS.num_generated_cases / len(sizes)) + 1 @@ -703,14 +689,20 @@ def named_cases_from_sampler(gen): yield case +# Random sampling for every parameterized test is expensive. Do it once and +# cache the result. +@functools.lru_cache(maxsize=None) +def _choice(n, m): + rng = np.random.RandomState(42) + return rng.choice(n, size=m, replace=False) + def sample_product_testcases(*args, **kw): """Non-decorator form of sample_product.""" args = [list(arg) for arg in args] kw = [(k, list(v)) for k, v in kw.items()] n = prod(len(a) for a in args) * prod(len(v) for _, v in kw) - rng = np.random.RandomState(42) testcases = [] - for i in rng.choice(n, size=min(n, FLAGS.num_generated_cases), replace=False): + for i in _choice(n, min(n, FLAGS.num_generated_cases)): testcase = {} for a in args: testcase.update(a[i % len(a)]) diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 81bb5f5b5..e552c7bd8 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -21,7 +21,6 @@ import unittest from absl import logging from absl.testing import absltest -from absl.testing import parameterized import jax from jax import ad_checkpoint @@ -200,12 +199,10 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): self.assertIsNotNone(f_tf.get_concrete_function()) - @parameterized.named_parameters(jtu.cases_from_list( - dict(testcase_name=f"_dtype={dtype.__name__}_function={with_function}", - dtype=dtype, - with_function=with_function) - for dtype in [np.int64, np.float64] - for with_function in [True, False])) + @jtu.sample_product( + dtype=[np.int64, np.float64], + with_function=[True, False], + ) def test_converts_64bit(self, dtype=np.int64, with_function=False): if not config.jax_enable_x64: self.skipTest("requires x64 mode") @@ -251,10 +248,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x))) self.ConvertAndCompare(f_jax, 0.7) - @parameterized.named_parameters(jtu.cases_from_list( - dict(testcase_name=f"function={with_function}", - with_function=with_function) - for with_function in [False, True])) + @jtu.sample_product(with_function=[False, True]) def test_gradients_disabled(self, with_function=False): f_tf = jax2tf.convert(jnp.tan, with_gradient=False) if with_function: @@ -270,10 +264,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): y = f_tf(x) _ = tape.gradient(y, x) - @parameterized.named_parameters(jtu.cases_from_list( - dict(testcase_name=f"function={with_function}", - with_function=with_function) - for with_function in [False, True])) + @jtu.sample_product(with_function=[False, True]) def test_gradients(self, with_function=True): def f(x, y): return x * x, x * y @@ -291,10 +282,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): self.assertAllClose(5., tape.gradient(v, x)) self.assertAllClose(4., tape.gradient(v, y)) - @parameterized.named_parameters(jtu.cases_from_list( - dict(testcase_name=f"function={with_function}", - with_function=with_function) - for with_function in [False, True])) + @jtu.sample_product(with_function=[False, True]) def test_gradients_pytree(self, with_function=True): def f(xy: Tuple[float, float]) -> Dict[str, float]: x, y = xy @@ -368,10 +356,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): self.assertAllClose(grad_jax.b, grad_tf[1]) - @parameterized.named_parameters(jtu.cases_from_list( - dict(testcase_name=f"function={with_function}", - with_function=with_function) - for with_function in [False, True])) + @jtu.sample_product(with_function=[False, True]) def test_gradients_with_ordered_dict_input(self, with_function=True): def f(inputs): out = 0.0 @@ -394,10 +379,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): self.assertAllClose(np.array([1.]), tape.gradient(u, x).numpy()) self.assertAllClose(np.array([1., 1.]), tape.gradient(u, y).numpy()) - @parameterized.named_parameters(jtu.cases_from_list( - dict(testcase_name=f"function={with_function}", - with_function=with_function) - for with_function in [False, True])) + @jtu.sample_product(with_function=[False, True]) def test_gradients_with_custom_jvp(self, with_function=True): """Check gradients, for a function with custom JVP.""" @jax.custom_jvp @@ -428,10 +410,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): self.assertAllClose(4. * 4., y) self.assertAllClose(3. * 4., tape.gradient(y, x)) - @parameterized.named_parameters(jtu.cases_from_list( - dict(testcase_name=f"function={with_function}", - with_function=with_function) - for with_function in [False, True])) + @jtu.sample_product(with_function=[False, True]) def test_gradients_with_custom_vjp(self, with_function=True): """Check gradients, for a function with custom VJP.""" @jax.custom_vjp @@ -498,10 +477,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): self.assertAllClose(jnp.zeros(np.shape(d_dx_jax), np.int32), d_dx_tf.numpy()) - @parameterized.named_parameters(jtu.cases_from_list( - dict(testcase_name=f"function={with_function}", - with_function=with_function) - for with_function in [False, True])) + @jtu.sample_product(with_function=[False, True]) def test_gradients_unused_argument_readme(self, with_function=False): # x1 and x3 are not used. x3 has integer type. def fn(x0, x1, x2, x3): @@ -546,10 +522,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): self.assertAllClose(g_jax2tf[2].numpy(), np.float32(2.)) self.assertAllClose(g_jax2tf[3].numpy(), np.int32(0)) - @parameterized.named_parameters(jtu.cases_from_list( - dict(testcase_name=f"function={with_function}", - with_function=with_function) - for with_function in [False, True])) + @jtu.sample_product(with_function=[False, True]) def test_gradients_int_argument(self, with_function=False): # https://github.com/google/jax/issues/6975 # Also issue #6975. @@ -875,9 +848,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): jax2tf.convert(outer)(2.) - @parameterized.named_parameters(jtu.cases_from_list( - dict(testcase_name=f"_{transform}", transform=transform) - for transform in ["jit", "jvp", "grad", "vmap"])) + @jtu.sample_product(transform=["jit", "jvp", "grad", "vmap"]) def test_convert_under_transform_error(self, transform="vmap"): def outer(y): return jax2tf.convert(jnp.sin)(y) # Inner convert takes tracer args @@ -886,9 +857,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): ValueError, "convert must be used outside all JAX transformations"): self.TransformConvertAndCompare(outer, np.ones((4,)), transform) - @parameterized.named_parameters(jtu.cases_from_list( - dict(testcase_name=f"_{transform}", transform=transform) - for transform in ["jit", "jvp", "grad", "vmap"])) + @jtu.sample_product(transform=["jit", "jvp", "grad", "vmap"]) def test_convert_under_transform_error_non_tracer(self, transform="vmap"): def outer(y): sin_1 = jax2tf.convert(jnp.sin)(1.) # Inner convert takes non-tracer arg @@ -997,10 +966,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): self.assertAllClose(tf_fn(tf.constant(1.375, tf.bfloat16)).numpy(), jnp.bfloat16(2.750)) - @parameterized.named_parameters(jtu.cases_from_list( - dict(testcase_name=f"function={with_function}", - with_function=with_function) - for with_function in [False, True])) + @jtu.sample_product(with_function=[False, True]) def test_kwargs(self, with_function=True): # Re: https://github.com/google/jax/issues/6791 def f_jax(*, x): @@ -1012,10 +978,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): f_tf(x=np.zeros(3, dtype=np.float32)), # Call with kwargs. np.zeros((), dtype=np.float32)) - @parameterized.named_parameters(jtu.cases_from_list( - dict(testcase_name=f"function={with_function}", - with_function=with_function) - for with_function in [False, True])) + @jtu.sample_product(with_function=[False, True]) def test_grad_kwargs(self, with_function=False): # Re: https://github.com/google/jax/issues/6791 x = (np.zeros(3, dtype=np.float32), diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index 20d1e01ef..408b420a3 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -294,29 +294,19 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase): f_jax = jax.jit(lambda i: params[i]) self.ConvertAndCompare(f_jax, indices) - @parameterized.named_parameters( - jtu.cases_from_list( - dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax) - for f_jax in REDUCE)) + @jtu.sample_product(f_jax=REDUCE) def test_reduce_ops_with_numerical_input(self, f_jax): values = np.array([1, 2, 3], dtype=np.float32) self.ConvertAndCompare(f_jax, values) - @parameterized.named_parameters( - jtu.cases_from_list( - dict(testcase_name=f"_{op}", op=op) for op in ( - "add", "max", "min", "multiply", "set" - ))) + @jtu.sample_product(op=["add", "max", "min", "multiply", "set"]) def test_scatter_static(self, op): values = np.ones((5, 6), dtype=np.float32) update = np.float32(6.) f_jax = jax.jit(lambda v, u: getattr(v.at[::2, 3:], op)(u)) self.ConvertAndCompare(f_jax, values, update) - @parameterized.named_parameters( - jtu.cases_from_list( - dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax) - for f_jax in REDUCE)) + @jtu.sample_product(f_jax=REDUCE) def test_reduce_ops_with_boolean_input(self, f_jax): values = np.array([True, False, True], dtype=np.bool_) self.ConvertAndCompare(f_jax, values) diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index abf3df62a..d5c8ee739 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -834,10 +834,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase): res_jax, jax2tf.convert(f, polymorphic_shapes=["(b, h)", "h"])(x, y)) - @parameterized.named_parameters(jtu.cases_from_list( - dict(testcase_name=f"function={with_function}", - with_function=with_function) - for with_function in [False, True])) + @jtu.sample_product(with_function=[False, True]) def test_grad_int(self, with_function=True): # https://github.com/google/jax/issues/7093 # Also issue #6975. diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index 5ef8a8fc6..cc2f3a6c0 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -14,7 +14,7 @@ import unittest -from absl.testing import absltest, parameterized +from absl.testing import absltest import jax from jax.config import config @@ -67,16 +67,12 @@ class DLPackTest(jtu.JaxTestCase): if jtu.device_under_test() == "tpu": self.skipTest("DLPack not supported on TPU") - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_take_ownership={}_gpu={}".format( - jtu.format_shape_dtype_string(shape, dtype), - take_ownership, gpu), - "shape": shape, "dtype": dtype, "take_ownership": take_ownership, - "gpu": gpu} - for shape in all_shapes - for dtype in dlpack_dtypes - for take_ownership in [False, True] - for gpu in [False, True])) + @jtu.sample_product( + shape=all_shapes, + dtype=dlpack_dtypes, + take_ownership=[False, True], + gpu=[False, True], + ) @jtu.skip_on_devices("rocm") # TODO(sharadmv,phawkins): see GH issue #10973 def testJaxRoundTrip(self, shape, dtype, take_ownership, gpu): rng = jtu.rand_default(self.rng()) @@ -95,12 +91,10 @@ class DLPackTest(jtu.JaxTestCase): "DLPack tensor may be consumed at most once", lambda: jax.dlpack.from_dlpack(dlpack)) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}".format( - jtu.format_shape_dtype_string(shape, dtype)), - "shape": shape, "dtype": dtype} - for shape in all_shapes - for dtype in dlpack_dtypes)) + @jtu.sample_product( + shape=all_shapes, + dtype=dlpack_dtypes, + ) @unittest.skipIf(not tf, "Test requires TensorFlow") @jtu.skip_on_devices("rocm") # TODO(sharadmv,phawkins): see GH issue #10973 def testTensorFlowToJax(self, shape, dtype): @@ -121,12 +115,10 @@ class DLPackTest(jtu.JaxTestCase): y = jax.dlpack.from_dlpack(dlpack) self.assertAllClose(np, y) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}".format( - jtu.format_shape_dtype_string(shape, dtype)), - "shape": shape, "dtype": dtype} - for shape in all_shapes - for dtype in dlpack_dtypes)) + @jtu.sample_product( + shape=all_shapes, + dtype=dlpack_dtypes, + ) @unittest.skipIf(not tf, "Test requires TensorFlow") def testJaxToTensorFlow(self, shape, dtype): if not config.x64_enabled and dtype in [jnp.int64, jnp.uint64, @@ -145,12 +137,10 @@ class DLPackTest(jtu.JaxTestCase): y = tf.experimental.dlpack.from_dlpack(dlpack) self.assertAllClose(np, y.numpy()) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}".format( - jtu.format_shape_dtype_string(shape, dtype)), - "shape": shape, "dtype": dtype} - for shape in all_shapes - for dtype in torch_dtypes)) + @jtu.sample_product( + shape=all_shapes, + dtype=torch_dtypes, + ) @unittest.skipIf(not torch, "Test requires PyTorch") def testTorchToJax(self, shape, dtype): if not config.x64_enabled and dtype in [jnp.int64, jnp.float64]: @@ -177,12 +167,10 @@ class DLPackTest(jtu.JaxTestCase): xla_client._xla.dlpack_managed_tensor_to_buffer( y, client) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}".format( - jtu.format_shape_dtype_string(shape, dtype)), - "shape": shape, "dtype": dtype} - for shape in all_shapes - for dtype in torch_dtypes)) + @jtu.sample_product( + shape=all_shapes, + dtype=torch_dtypes, + ) @unittest.skipIf(not torch, "Test requires PyTorch") def testJaxToTorch(self, shape, dtype): if not config.x64_enabled and dtype in [jnp.int64, jnp.float64]: @@ -194,12 +182,10 @@ class DLPackTest(jtu.JaxTestCase): y = torch.utils.dlpack.from_dlpack(dlpack) self.assertAllClose(np, y.cpu().numpy()) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}".format( - jtu.format_shape_dtype_string(shape, dtype)), - "shape": shape, "dtype": dtype} - for shape in all_shapes - for dtype in torch_dtypes)) + @jtu.sample_product( + shape=all_shapes, + dtype=torch_dtypes, + ) @unittest.skipIf(numpy_version < (1, 22, 0), "Requires numpy 1.22 or newer") @jtu.skip_on_devices("rocm") # TODO(sharadmv,phawkins): see GH issue #10973 def testNumpyToJax(self, shape, dtype): @@ -208,12 +194,10 @@ class DLPackTest(jtu.JaxTestCase): x_jax = jnp.from_dlpack(x_np) self.assertAllClose(x_np, x_jax) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}".format( - jtu.format_shape_dtype_string(shape, dtype)), - "shape": shape, "dtype": dtype} - for shape in all_shapes - for dtype in torch_dtypes)) + @jtu.sample_product( + shape=all_shapes, + dtype=torch_dtypes, + ) @unittest.skipIf(numpy_version < (1, 22, 0), "Requires numpy 1.22 or newer") @jtu.skip_on_devices("gpu") def testJaxToNumpy(self, shape, dtype): @@ -230,12 +214,10 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase): if jtu.device_under_test() != "gpu": self.skipTest("__cuda_array_interface__ is only supported on GPU") - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}".format( - jtu.format_shape_dtype_string(shape, dtype)), - "shape": shape, "dtype": dtype} - for shape in all_shapes - for dtype in dlpack_dtypes)) + @jtu.sample_product( + shape=all_shapes, + dtype=dlpack_dtypes, + ) @unittest.skipIf(not cupy, "Test requires CuPy") def testJaxToCuPy(self, shape, dtype): if dtype == jnp.bfloat16: diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index 179a9a72f..522ff0a49 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -24,7 +24,6 @@ import unittest from unittest import skip, SkipTest from absl.testing import absltest -from absl.testing import parameterized import jax from jax import ad_checkpoint @@ -514,12 +513,7 @@ class HostCallbackTapTest(jtu.JaxTestCase): self.assertEqual( len(local_devices()), len(re.findall(r"112", testing_stream.output))) - @parameterized.named_parameters( - jtu.cases_from_list( - dict( - testcase_name=f"_with_jit_{with_jit}", - with_jit=with_jit) - for with_jit in [True, False])) + @jtu.sample_product(with_jit=[True, False]) def test_tap_pytree(self, with_jit=False): def func(x, what=""): """Returns some pytrees depending on x""" @@ -551,12 +545,7 @@ class HostCallbackTapTest(jtu.JaxTestCase): hcb.barrier_wait() # Wait for receivers to be done self.assertEqual(3, tap_count) - @parameterized.named_parameters( - jtu.cases_from_list( - dict( - testcase_name=f"_concurrent_{concurrent}", - concurrent=concurrent) - for concurrent in [True, False])) + @jtu.sample_product(concurrent=[True, False]) def test_tap_multiple(self, concurrent=False): """Call id_tap multiple times, concurrently or in sequence. """ if concurrent and jtu.device_under_test() in ["cpu", "gpu"]: @@ -628,12 +617,7 @@ class HostCallbackTapTest(jtu.JaxTestCase): [t.start() for t in threads] [t.join() for t in threads] - @parameterized.named_parameters( - jtu.cases_from_list( - dict( - testcase_name=f"_with_jit_{with_jit}", - with_jit=with_jit) - for with_jit in [True, False])) + @jtu.sample_product(with_jit=[True, False]) def test_tap_cond(self, with_jit=False): """A conditional""" @@ -663,11 +647,7 @@ class HostCallbackTapTest(jtu.JaxTestCase): where: end 4""", testing_stream.output) - @parameterized.named_parameters( - jtu.cases_from_list( - dict(testcase_name=f"_with_jit_{with_jit}", - with_jit=with_jit) - for with_jit in [True, False])) + @jtu.sample_product(with_jit=[True, False]) def test_tap_while_cond(self, with_jit=False): def func(x): x1 = hcb.id_print(x, where="1", output_stream=testing_stream) @@ -741,12 +721,7 @@ class HostCallbackTapTest(jtu.JaxTestCase): where: 3 3""", testing_stream.output) - @parameterized.named_parameters( - jtu.cases_from_list( - dict( - testcase_name=f"_with_jit_{with_jit}", - with_jit=with_jit) - for with_jit in [True, False])) + @jtu.sample_product(with_jit=[True, False]) def test_tap_scan_cond(self, with_jit=True): def func(x): x1 = hcb.id_print(x, where="1", output_stream=testing_stream) @@ -796,15 +771,11 @@ class HostCallbackTapTest(jtu.JaxTestCase): [1 2 3]""", testing_stream.output) testing_stream.reset() - @parameterized.named_parameters( - jtu.cases_from_list( - dict( - testcase_name=f"_shape_{shape}_dtype_{np.dtype(dtype).name}_nr_args={nr_args}", - shape=shape, - dtype=dtype, - nr_args=nr_args) for nr_args in [1, 2] - for shape in [(), (2,), (2, 3), (2, 3, 4)] - for dtype in jtu.dtypes.all)) + @jtu.sample_product( + nr_args=[1, 2], + shape=[(), (2,), (2, 3), (2, 3, 4)], + dtype=jtu.dtypes.all, + ) def test_tap_jit_dtypes(self, nr_args=2, dtype=jnp.int16, shape=(2,)): if dtype in (jnp.complex64, jnp.complex128, jnp.bool_): raise SkipTest(f"host_callback not implemented for {dtype}.") @@ -1971,13 +1942,11 @@ class HostCallbackTapTest(jtu.JaxTestCase): 10""" self.assertMultiLineStrippedEqual(expected, testing_stream.output) - @parameterized.named_parameters( - jtu.cases_from_list( - dict(testcase_name=f"_use_remat={use_remat}_{grad_func}_use_result={use_result}", - use_result=use_result, use_remat=use_remat, grad_func=grad_func) - for use_result in [True, False] - for grad_func in ["grad", "value_and_grad"] - for use_remat in ["old", "new", "none"])) + @jtu.sample_product( + use_result=[True, False], + grad_func=["grad", "value_and_grad"], + use_remat=["old", "new", "none"], + ) def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="new"): if use_remat == "old": raise SkipTest() @@ -2108,11 +2077,9 @@ class HostCallbackCallTest(jtu.JaxTestCase): self.assertAllClose(2 * arg, fun(arg)) self.assertEqual(count[0], 1) - @parameterized.named_parameters( - jtu.cases_from_list( - dict(testcase_name=f"_{np.dtype(dtype).name}", dtype=dtype) - for dtype in jtu.dtypes.all - if dtype != np.bool_)) + @jtu.sample_product( + dtype=[dtype for dtype in jtu.dtypes.all if dtype != np.bool_], + ) def test_call_types(self, dtype=np.float64): def f_outside(x): diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index a6ad232fc..fcf2c3739 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -133,7 +133,7 @@ LAX_GRAD_OPS = [ grad_test_spec(lax.rsqrt, nargs=1, order=2, rng_factory=jtu.rand_default, dtypes=grad_complex_dtypes), grad_test_spec(lax.cbrt, nargs=1, order=2, rng_factory=jtu.rand_default, - dtypes=grad_float_dtypes, tol={np.float64: 3e-5}), + dtypes=grad_float_dtypes, tol={np.float64: 5e-3}), grad_test_spec(lax.logistic, nargs=1, order=2, rng_factory=jtu.rand_default, dtypes=grad_inexact_dtypes), @@ -195,38 +195,43 @@ def check_grads_bilinear(f, args, order, class LaxAutodiffTest(jtu.JaxTestCase): - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix( - rec.name, shapes, itertools.repeat(dtype)), - "op": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes, "dtype": dtype, - "order": rec.order, "tol": rec.tol} - for shape_group in compatible_shapes + @parameterized.parameters(itertools.chain.from_iterable( + jtu.sample_product_testcases( + [dict(op=rec.op, rng_factory=rec.rng_factory, order=rec.order, tol=rec.tol)], + shapes=[ + shapes for shape_group in compatible_shapes for shapes in itertools.combinations_with_replacement(shape_group, rec.nargs) - for dtype in rec.dtypes) - for rec in LAX_GRAD_OPS)) + ], + dtype=rec.dtypes, + ) + for rec in LAX_GRAD_OPS + )) def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol): rng = rng_factory(self.rng()) - if jtu.device_under_test() == "tpu" and op is lax.pow: - raise SkipTest("pow grad imprecise on tpu") + if jtu.device_under_test() == "tpu": + if op is lax.pow: + raise SkipTest("pow grad imprecise on tpu") + if op is lax.cos: + order = 1 # 2nd-order gradient is imprecise on TPU. + tol = jtu.join_tolerance(1e-1, tol) if jtu.num_float_bits(dtype) == 32 else tol args = tuple(rng(shape, dtype) for shape in shapes) check_grads(op, args, order, ["fwd", "rev"], tol, tol) - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": f"_{rec.op.__name__}_{special_value}", - "op": rec.op, "special_value": special_value, "tol": rec.tol} - for special_value in rec.values) - for rec in LAX_GRAD_SPECIAL_VALUE_TESTS)) + @parameterized.parameters(itertools.chain.from_iterable( + jtu.sample_product_testcases( + [dict(op=rec.op, tol=rec.tol)], + special_value=rec.values, + ) + for rec in LAX_GRAD_SPECIAL_VALUE_TESTS + )) def testOpGradSpecialValue(self, op, special_value, tol): check_grads(op, (special_value,), 2, ["fwd", "rev"], rtol=tol, atol=tol) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_from_dtype={}_to_dtype={}".format( - jtu.dtype_str(from_dtype), jtu.dtype_str(to_dtype)), - "from_dtype": from_dtype, "to_dtype": to_dtype} - for from_dtype, to_dtype in itertools.product(inexact_dtypes, repeat=2))) + @jtu.sample_product( + from_dtype=inexact_dtypes, + to_dtype=inexact_dtypes, + ) def testConvertElementTypeGrad(self, from_dtype, to_dtype): rng = jtu.rand_default(self.rng()) tol = max(jtu.tolerance(to_dtype, jtu.default_gradient_tolerance), @@ -237,12 +242,10 @@ class LaxAutodiffTest(jtu.JaxTestCase): convert_element_type) check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}".format( - jtu.format_shape_dtype_string(shape, dtype)), - "shape": shape, "dtype": dtype} - for shape in [(), (2, 3)] - for dtype in grad_float_dtypes)) + @jtu.sample_product( + shape=[(), (2, 3)], + dtype=grad_float_dtypes, + ) def testClampGrad(self, shape, dtype): rng = jtu.rand_default(self.rng()) operand = rng(shape, dtype) @@ -253,15 +256,14 @@ class LaxAutodiffTest(jtu.JaxTestCase): check_grads(lax.clamp, (low, operand, high), 2, ["fwd", "rev"], eps=1e-2) check_grads(lax.clamp, (low, high, operand), 2, ["fwd", "rev"], eps=1e-2) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format( - dim, ",".join(str(d) for d in base_shape), np.dtype(dtype).name, - num_arrs), - "dim": dim, "base_shape": base_shape, "dtype": dtype, "num_arrs": num_arrs} - for num_arrs in [3] - for dtype in float_dtypes + @jtu.sample_product( + [dict(base_shape=base_shape, dim=dim) for base_shape in [(4,), (3, 4), (2, 3, 4)] - for dim in range(len(base_shape)))) + for dim in range(len(base_shape)) + ], + num_arrs=[3], + dtype=float_dtypes, + ) def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs): rng = jtu.rand_default(self.rng()) shapes = [base_shape[:dim] + (size,) + base_shape[dim+1:] @@ -270,21 +272,17 @@ class LaxAutodiffTest(jtu.JaxTestCase): concatenate = lambda *args: lax.concatenate(args, dim) check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_lhs_shape={}_rhs_shape={}_strides={}_padding={}" - .format(jtu.format_shape_dtype_string(lhs_shape, dtype), - jtu.format_shape_dtype_string(rhs_shape, dtype), - strides, padding), - "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, - "strides": strides, "padding": padding} + @jtu.sample_product( + [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, strides=strides) for lhs_shape, rhs_shape, all_strides in itertools.chain( [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)]) for b, i, j in itertools.product([2, 3], repeat=3)], [((4, 2, 1), (3, 2, 1), [(1,)])]) for strides in all_strides - for dtype in float_dtypes - for padding in ["VALID", "SAME"])) + ], + dtype=float_dtypes, + padding=["VALID", "SAME"], + ) def testConvGrad(self, lhs_shape, rhs_shape, dtype, strides, padding): rng = jtu.rand_small(self.rng()) lhs = rng(lhs_shape, dtype) @@ -294,18 +292,11 @@ class LaxAutodiffTest(jtu.JaxTestCase): check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"], atol=1e-2, rtol=1e-2) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_" - "rhs_dilation={}" - .format(jtu.format_shape_dtype_string(lhs_shape, dtype), - jtu.format_shape_dtype_string(rhs_shape, dtype), - strides, padding, lhs_dil, rhs_dil), - "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, - "strides": strides, "padding": padding, "lhs_dil": lhs_dil, - "rhs_dil": rhs_dil} + @jtu.sample_product( + [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, strides=strides, + padding=padding, lhs_dil=lhs_dil, rhs_dil=rhs_dil) for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in - itertools.chain( + itertools.chain( [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)], [((0, 0), (0, 0)), ((-1, 0), (0, -1)), ((1, 0), (0, 1))], [(1, 1), (2, 1)], [(1, 1)]) @@ -315,8 +306,10 @@ class LaxAutodiffTest(jtu.JaxTestCase): for strides in all_strides for rhs_dil in rhs_dils for lhs_dil in lhs_dils - for dtype in float_dtypes - for padding in all_pads)) + for padding in all_pads + ], + dtype=float_dtypes, + ) def testConvWithGeneralPaddingGrad(self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil): rng = jtu.rand_small(self.rng()) @@ -328,19 +321,11 @@ class LaxAutodiffTest(jtu.JaxTestCase): check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"], atol=1e-2, rtol=1e-2) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_" - "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}" - .format(jtu.format_shape_dtype_string(lhs_shape, dtype), - jtu.format_shape_dtype_string(rhs_shape, dtype), - strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums), - feature_group_count, batch_group_count), - "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, - "strides": strides, "padding": padding, "lhs_dil": lhs_dil, - "rhs_dil": rhs_dil, "dimension_numbers": dim_nums, - "perms": perms, "feature_group_count": feature_group_count, - "batch_group_count": batch_group_count} + @jtu.sample_product( + [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, strides=strides, + padding=padding, lhs_dil=lhs_dil, rhs_dil=rhs_dil, + feature_group_count=feature_group_count, + batch_group_count=batch_group_count) for batch_group_count, feature_group_count in ([(1, 1), (2, 1), (1, 2)]) for lhs_shapes, rhs_shape, all_strides, lhs_dils, rhs_dils in [ ([(b * batch_group_count, i * feature_group_count, 6, 7), @@ -354,13 +339,17 @@ class LaxAutodiffTest(jtu.JaxTestCase): for strides in all_strides for rhs_dil in rhs_dils for lhs_dil in lhs_dils - for dtype in grad_inexact_dtypes for padding in ([((0, 0), (0, 0)), ((1, 0), (0, 1))] + - ([((0, -1), (0, 0))] if lhs_shape[2] != 0 else [])) + ([((0, -1), (0, 0))] if lhs_shape[2] != 0 else [])) + ], + [dict(dimension_numbers=dim_nums, perms=perms) for dim_nums, perms in [ - (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])), - (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])), - (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))])) + (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])), + (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])), + (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))] + ], + dtype=grad_inexact_dtypes, + ) def testConvGeneralDilatedGrad(self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil, dimension_numbers, perms, feature_group_count, batch_group_count): @@ -386,13 +375,11 @@ class LaxAutodiffTest(jtu.JaxTestCase): check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"], atol=tol, rtol=tol) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_lhs_shape={}_rhs_shape={}".format( - jtu.format_shape_dtype_string(lhs_shape, dtype), - jtu.format_shape_dtype_string(rhs_shape, dtype)), - "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype} - for lhs_shape in [(2,), (3, 2)] for rhs_shape in [(2,), (2, 4)] - for dtype in float_dtypes)) + @jtu.sample_product( + lhs_shape=[(2,), (3, 2)], + rhs_shape=[(2,), (2, 4)], + dtype=float_dtypes, + ) def testDotGrad(self, lhs_shape, rhs_shape, dtype): rng = jtu.rand_default(self.rng()) tol = {np.float16: 1e-1, np.float32: 1e-4} @@ -407,14 +394,9 @@ class LaxAutodiffTest(jtu.JaxTestCase): s = str(jax.make_jaxpr(pullback)(gresult)) assert "Precision.HIGHEST" in s - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_lhs_shape={}_rhs_shape={}_dimension_numbers={}" - .format(jtu.format_shape_dtype_string(lhs_shape, dtype), - jtu.format_shape_dtype_string(rhs_shape, dtype), - dimension_numbers), - "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, - "dimension_numbers": dimension_numbers} + @jtu.sample_product( + [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, + dimension_numbers=dimension_numbers) for lhs_shape, rhs_shape, dimension_numbers in [ ((3, 2), (2, 4), (([1], [0]), ([], []))), ((3, 5), (2, 5), (([1], [1]), ([], []))), @@ -423,7 +405,9 @@ class LaxAutodiffTest(jtu.JaxTestCase): ((3, 5, 2), (2, 4, 5), (([2], [0]), ([1], [2]))), ((7, 3, 5, 2), (2, 2, 4, 5), (([3], [0]), ([2], [3]))), ] - for dtype in float_dtypes)) + ], + dtype=float_dtypes, + ) def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype, dimension_numbers): rng = jtu.rand_small(self.rng()) @@ -438,46 +422,36 @@ class LaxAutodiffTest(jtu.JaxTestCase): s = str(jax.make_jaxpr(pullback)(gresult)) assert "Precision.HIGHEST" in s - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format( - shape, np.dtype(dtype).name, broadcast_sizes), - "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes} - for shape in [(), (2, 3)] - for dtype in float_dtypes - for broadcast_sizes in [(), (2,), (1, 2)])) + @jtu.sample_product( + shape=[(), (2, 3)], + dtype=float_dtypes, + broadcast_sizes=[(), (2,), (1, 2)], + ) def testBroadcastGrad(self, shape, dtype, broadcast_sizes): rng = jtu.rand_default(self.rng()) args = (rng(shape, dtype),) broadcast = lambda x: lax.broadcast(x, broadcast_sizes) check_grads(broadcast, args, 2, ["fwd", "rev"], eps=1.) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_inshape={}_outshape={}_bcdims={}".format( - jtu.format_shape_dtype_string(inshape, dtype), - outshape, broadcast_dimensions), - "inshape": inshape, "dtype": dtype, "outshape": outshape, - "dimensions": broadcast_dimensions} + @jtu.sample_product( + [dict(inshape=inshape, outshape=outshape, dimensions=broadcast_dimensions) for inshape, outshape, broadcast_dimensions in [ ([2], [2, 2], [0]), ([2], [2, 2], [1]), ([2], [2, 3], [0]), ([], [2, 3], []), ] - for dtype in float_dtypes)) + ], + dtype=float_dtypes, + ) def testBroadcastInDimGrad(self, inshape, dtype, outshape, dimensions): rng = jtu.rand_default(self.rng()) operand = rng(inshape, dtype) broadcast_in_dim = lambda x: lax.broadcast_in_dim(x, outshape, dimensions) check_grads(broadcast_in_dim, (operand,), 2, ["fwd", "rev"], eps=1.) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_inshape={}_outshape={}_perm={}".format( - jtu.format_shape_dtype_string(arg_shape, dtype), - jtu.format_shape_dtype_string(out_shape, dtype), - permutation), - "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype, - "permutation": permutation} - for dtype in float_dtypes + @jtu.sample_product( + [dict(arg_shape=arg_shape, out_shape=out_shape, permutation=permutation) for arg_shape, out_shape, permutation in [ [(3, 4), (12,), None], [(2, 1, 4), (8,), None], @@ -488,23 +462,26 @@ class LaxAutodiffTest(jtu.JaxTestCase): [(2, 1, 4), (8,), (2, 0, 1)], [(2, 2, 4), (2, 8), (0, 2, 1)], [(2, 2, 4), (2, 8), (2, 0, 1)], - ])) + ] + ], + dtype=float_dtypes, + ) def testReshapeGrad(self, arg_shape, out_shape, permutation, dtype): rng = jtu.rand_default(self.rng()) operand = rng(arg_shape, dtype) reshape = lambda x: lax.reshape(x, out_shape, permutation) check_grads(reshape, (operand,), 2, ["fwd", "rev"], eps=1.) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_inshape={}_pads={}" - .format(jtu.format_shape_dtype_string(shape, dtype), pads), - "shape": shape, "dtype": dtype, "pads": pads} - for dtype in float_dtypes + @jtu.sample_product( + [dict(shape=shape, pads=pads) for shape, paddings in [ [(), [()]], ((2, 3), [[(1, 2, 1), (0, 1, 0)], [(-1, 0, 0), (-1, 0, 2)]]), ] - for pads in paddings)) + for pads in paddings + ], + dtype=float_dtypes, + ) def testPadGrad(self, shape, dtype, pads): rng = jtu.rand_small(self.rng()) operand = rng(shape, dtype) @@ -546,14 +523,13 @@ class LaxAutodiffTest(jtu.JaxTestCase): # self.assertEqual(result, 0.0) self.assertAllClose(result, np.nan) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_predshape={}_argshapes={}".format( - jtu.format_shape_dtype_string(pred_shape, np.bool_), - jtu.format_shape_dtype_string(arg_shape, dtype)), - "pred_shape": pred_shape, "arg_shape": arg_shape, "dtype": dtype} + @jtu.sample_product( + [dict(arg_shape=arg_shape, pred_shape=pred_shape) for arg_shape in [(), (3,), (2, 3)] for pred_shape in ([(), arg_shape] if arg_shape else [()]) - for dtype in float_dtypes)) + ], + dtype=float_dtypes, + ) def testSelectGrad(self, pred_shape, arg_shape, dtype): rng = jtu.rand_default(self.rng()) pred = rng(pred_shape, np.bool_) @@ -562,13 +538,9 @@ class LaxAutodiffTest(jtu.JaxTestCase): select = lambda on_true, on_false: lax.select(pred, on_true, on_false) check_grads(select, (on_true, on_false), 2, ["fwd", "rev"], eps=1.) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_shape={}_start_indices={}_limit_indices={}_strides={}".format( - jtu.format_shape_dtype_string(shape, dtype), - start_indices, limit_indices, strides), - "shape": shape, "dtype": dtype, "starts": start_indices, - "limits": limit_indices, "strides": strides} + @jtu.sample_product( + [dict(shape=shape, starts=start_indices, limits=limit_indices, + strides=strides) for shape, start_indices, limit_indices, strides in [ [(3,), (1,), (2,), None], [(7,), (4,), (7,), None], @@ -581,44 +553,43 @@ class LaxAutodiffTest(jtu.JaxTestCase): [(5, 3), (1, 1), (5, 3), (2, 1)], [(3, 3, 5), (0, 2, 0), (3, 2, 5), (1, 2, 1)] ] - for dtype in float_dtypes)) + ], + dtype=float_dtypes, + ) def testSliceGrad(self, shape, dtype, starts, limits, strides): rng = jtu.rand_default(self.rng()) operand = rng(shape, dtype) slice = lambda x: lax.slice(x, starts, limits, strides) check_grads(slice, (operand,), 2, ["fwd", "rev"], eps=1.) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_start_indices={}_size_indices={}".format( - jtu.format_shape_dtype_string(shape, dtype), - start_indices, size_indices), - "shape": shape, "dtype": dtype, "start_indices": start_indices, - "size_indices": size_indices} + @jtu.sample_product( + [dict(shape=shape, start_indices=start_indices, size_indices=size_indices) for shape, start_indices, size_indices in [ [(3,), (1,), (1,)], [(5, 3), (1, 1), (3, 1)], [(7, 5, 3), (4, 1, 0), (2, 0, 1)], ] - for dtype in float_dtypes)) + ], + dtype=float_dtypes, + ) def testDynamicSliceGrad(self, shape, dtype, start_indices, size_indices): rng = jtu.rand_default(self.rng()) operand = rng(shape, dtype) dynamic_slice = lambda x: lax.dynamic_slice(x, start_indices, size_indices) check_grads(dynamic_slice, (operand,), 2, ["fwd", "rev"], eps=1.) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_start_indices={}_update_shape={}".format( - jtu.format_shape_dtype_string(shape, dtype), - start_indices, update_shape), - "shape": shape, "dtype": dtype, "start_indices": start_indices, - "update_shape": update_shape} + @jtu.sample_product( + [dict(shape=shape, start_indices=start_indices, update_shape=update_shape) for shape, start_indices, update_shape in [ [(3,), (1,), (1,)], [(5, 3), (1, 1), (3, 1)], [(7, 5, 3), (4, 1, 0), (2, 0, 1)], ] - for dtype in float_dtypes)) - def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices, update_shape): + ], + dtype=float_dtypes, + ) + def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices, + update_shape): rng = jtu.rand_default(self.rng()) operand = rng(shape, dtype) update = rng(update_shape, dtype) @@ -664,28 +635,25 @@ class LaxAutodiffTest(jtu.JaxTestCase): result2, _ = jax.value_and_grad(f, 0)(x, y) self.assertAllClose(result1, result2) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_perm={}".format( - jtu.format_shape_dtype_string(shape, dtype), perm), - "shape": shape, "dtype": dtype, "perm": perm} + @jtu.sample_product( + [dict(shape=shape, perm=perm) for shape, perm in [ [(3, 4), (1, 0)], [(3, 4), (0, 1)], [(3, 4, 5), (2, 1, 0)], [(3, 4, 5), (1, 0, 2)], ] - for dtype in float_dtypes)) + ], + dtype=float_dtypes, + ) def testTransposeGrad(self, shape, dtype, perm): rng = jtu.rand_default(self.rng()) operand = rng(shape, dtype) transpose = lambda x: lax.transpose(x, perm) check_grads(transpose, (operand,), 2, ["fwd", "rev"], eps=1.) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_op={}_inshape={}_reducedims={}" - .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims), - "op": op, "init_val": init_val, "shape": shape, "dtype": dtype, - "dims": dims, "rng_factory": rng_factory} + @jtu.sample_product( + [dict(init_val=init_val, op=op, dtype=dtype, rng_factory=rng_factory) for init_val, op, dtypes, rng_factory in [ (0, lax.add, float_dtypes + jtu.dtypes.complex, jtu.rand_default), (-np.inf, lax.max, grad_inexact_dtypes, jtu.rand_unique_int), @@ -693,6 +661,8 @@ class LaxAutodiffTest(jtu.JaxTestCase): (1, lax.mul, grad_float_dtypes, partial(jtu.rand_default, scale=1)), ] for dtype in dtypes + ], + [dict(shape=shape, dims=dims) for shape, dims in [ [(), ()], [(3, 4, 5), ()], @@ -702,7 +672,9 @@ class LaxAutodiffTest(jtu.JaxTestCase): [(3, 4, 5), (0, 1, 2)], [(3, 1), (1,)], [(3, 0, 5), (1,)], - ])) + ] + ], + ) def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory): rng = rng_factory(self.rng()) if jtu.device_under_test() == "tpu" and op is lax.mul: @@ -718,11 +690,8 @@ class LaxAutodiffTest(jtu.JaxTestCase): if op not in (lax.max, lax.min) or all(d > 0 for d in shape): check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_inshape={}_reducedims={}" - .format(jtu.format_shape_dtype_string(shape, dtype), dims), - "shape": shape, "dtype": dtype, "dims": dims} - for dtype in grad_float_dtypes + @jtu.sample_product( + [dict(shape=shape, dims=dims) for shape, dims in [ [(3, 4, 5), ()], [(3, 4, 5), (0,)], @@ -731,7 +700,10 @@ class LaxAutodiffTest(jtu.JaxTestCase): [(3, 4, 5), (0, 1, 2)], [(3, 1), (1,)], [(3, 0, 5), (1,)], - ])) + ] + ], + dtype=grad_float_dtypes, + ) def testReducePairGrad(self, shape, dtype, dims): rng = jtu.rand_default(self.rng(), scale=1) tol = {np.float32: 1e-2, np.float64: 1e-4} @@ -742,20 +714,16 @@ class LaxAutodiffTest(jtu.JaxTestCase): reduce = lambda xs, ys: lax.reduce((xs, ys), init_vals, op, dims) check_grads(reduce, operands, 2, ["fwd", "rev"], tol, tol) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}" - "_basedilation={}_windowdilation={}") - .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims, - strides, padding, base_dilation, window_dilation), - "op": op, "init_val": init_val, "dtype": dtype, "shape": shape, - "dims": dims, "strides": strides, "padding": padding, - "base_dilation": base_dilation, "window_dilation": window_dilation, - "rng_factory": rng_factory} + @jtu.sample_product( + [dict(init_val=init_val, op=op, dtype=dtype, rng_factory=rng_factory, + shape=shape, dims=dims, strides=strides, padding=padding, + base_dilation=base_dilation, window_dilation=window_dilation) for init_val, op, dtypes, rng_factory in [ (0, lax.add, grad_float_dtypes, jtu.rand_small), (-np.inf, lax.max, grad_float_dtypes, jtu.rand_unique_int), (np.inf, lax.min, grad_float_dtypes, jtu.rand_unique_int), ] + for dtype in dtypes for shape, dims, strides, padding, base_dilation, window_dilation in ( itertools.chain( itertools.product( @@ -772,7 +740,8 @@ class LaxAutodiffTest(jtu.JaxTestCase): ["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]], [(1, 1, 1, 1)] + ([(2, 1, 3, 2)]), [(1, 1, 1, 1)] + ([(1, 2, 2, 1)] if op is lax.add else [])))) - for dtype in dtypes)) + ], + ) @jtu.ignore_warning(category=UserWarning, message="Using reduced precision for gradient.*") def testReduceWindowGrad( @@ -808,20 +777,20 @@ class LaxAutodiffTest(jtu.JaxTestCase): check_grads(fun, (operand,), gradient_order, ["fwd", "rev"], tol, tol, eps) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_op={}_shape={}_axis={}_reverse={}" - .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis, - reverse), - "op": op, "shape": shape, "dtype": dtype, - "axis": axis, "reverse": reverse} + @jtu.sample_product( + [dict(op=op, dtype=dtype) for op, types in [ (lax.cumsum, [np.float32, np.float64]), (lax.cumprod, [np.float32, np.float64]), ] for dtype in types + ], + [dict(shape=shape, axis=axis) for shape in [[10], [3, 4, 5]] for axis in range(len(shape)) - for reverse in [False, True])) + ], + reverse=[False, True], + ) def testCumulativeReduceGrad(self, op, shape, dtype, axis, reverse): rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer) else jtu.rand_small) @@ -831,33 +800,30 @@ class LaxAutodiffTest(jtu.JaxTestCase): # TODO(b/205052657): enable more tests when supported - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_axis={}_isstable={}".format( - jtu.format_shape_dtype_string(shape, dtype), axis, is_stable), - "shape": shape, "dtype": dtype, "axis": axis, "is_stable": is_stable} - for dtype in [np.float32] + @jtu.sample_product( + [dict(shape=shape, axis=axis) for shape in [(5,), (5, 7)] for axis in [len(shape) - 1] - for is_stable in [False, True])) + ], + dtype=[np.float32], + is_stable=[False, True], + ) def testSortGrad(self, shape, dtype, axis, is_stable): - rng = jtu.rand_default(self.rng()) + rng = jtu.rand_unique_int(self.rng()) operand = rng(shape, dtype) sort = lambda x: lax.sort(x, dimension=axis, is_stable=is_stable) check_grads(sort, (operand,), 2, ["fwd", "rev"], eps=1e-2) # TODO(b/205052657): enable more tests when supported - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_keyshape={}_valshape={}_axis={}_isstable={}".format( - jtu.format_shape_dtype_string(shape, key_dtype), - jtu.format_shape_dtype_string(shape, val_dtype), - axis, is_stable), - "shape": shape, "key_dtype": key_dtype, "val_dtype": val_dtype, - "axis": axis, "is_stable": is_stable} - for key_dtype in [np.float32] - for val_dtype in [np.float32] + @jtu.sample_product( + [dict(shape=shape, axis=axis) for shape in [(3,), (5, 3)] for axis in [len(shape) - 1] - for is_stable in [False, True])) + ], + key_dtype=[np.float32], + val_dtype=[np.float32], + is_stable=[False, True], + ) def testSortKeyValGrad(self, shape, key_dtype, val_dtype, axis, is_stable): rng = jtu.rand_default(self.rng()) # This test relies on the property that wherever keys are tied, values are @@ -873,30 +839,28 @@ class LaxAutodiffTest(jtu.JaxTestCase): fun = lambda keys, values: lax.sort_key_val(keys, values, axis, is_stable) check_grads(fun, (keys, values), 2, ["fwd", "rev"], 1e-2, 1e-2, 1e-2) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_k={}".format( - jtu.format_shape_dtype_string(shape, dtype), k), - "shape": shape, "dtype": dtype, "k": k} - for dtype in [np.float32,] - for shape in [(4,), (5, 5), (2, 1, 4)] - for k in [1, 3])) + @jtu.sample_product( + dtype=[np.float32,], + shape=[(4,), (5, 5), (2, 1, 4)], + k=[1, 3], + ) def testTopKGrad(self, shape, dtype, k): flat_values = np.arange(prod(shape), dtype=dtype) values = self.rng().permutation(flat_values).reshape(shape) fun = lambda vs: lax.top_k(vs, k=k)[0] check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_idxs={}_axes={}".format( - jtu.format_shape_dtype_string(shape, dtype), idxs, axes), - "shape": shape, "dtype": dtype, "idxs": idxs, "axes": axes} - for dtype in float_dtypes + @jtu.sample_product( + [dict(shape=shape, idxs=idxs, axes=axes) for shape, idxs, axes in [ [(3, 4, 5), (np.array([0, 2, 1]),), (0,)], [(3, 4, 5), (np.array([-1, -2]),), (0,)], [(3, 4, 5), (np.array([0, 2]), np.array([1, 3])), (0, 1)], [(3, 4, 5), (np.array([0, 2]), np.array([1, 3])), (0, 2)], - ])) + ] + ], + dtype=float_dtypes, + ) @jax.numpy_rank_promotion('allow') # Test explicitly exercises implicit rank promotion. def testIndexTakeGrad(self, shape, dtype, idxs, axes): rng = jtu.rand_default(self.rng()) @@ -904,16 +868,9 @@ class LaxAutodiffTest(jtu.JaxTestCase): index_take = lambda src: lax.index_take(src, idxs, axes) check_grads(index_take, (src,), 2, ["fwd", "rev"], eps=1.) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - f"_shape={jtu.format_shape_dtype_string(shape, dtype)}" - f"_idxs={jtu.format_shape_dtype_string(idxs.shape, idxs.dtype)}" - f"_dnums={dnums}_slice_sizes={slice_sizes}_mode={mode}" - f"_iteration={iteration}", - "shape": shape, "dtype": dtype, "idxs_shape": idxs.shape, - "idxs_dtype": idxs.dtype, "dnums": dnums, "slice_sizes": slice_sizes, - "max_idx": max_idx, "mode": mode} - for dtype in grad_float_dtypes + @jtu.sample_product( + [dict(shape=shape, idxs_shape=idxs.shape, idxs_dtype=idxs.dtype, + dnums=dnums, slice_sizes=slice_sizes, max_idx=max_idx) for shape, idxs, dnums, slice_sizes, max_idx in [ ((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers( offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), @@ -925,10 +882,13 @@ class LaxAutodiffTest(jtu.JaxTestCase): offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3), 3), ] - for mode in ["clip", "fill", "promise_in_bounds"] - for iteration in range(5))) + ], + dtype=grad_float_dtypes, + mode=["clip", "fill", "promise_in_bounds"], + iteration=range(5), + ) def testGatherGrad(self, shape, dtype, idxs_shape, idxs_dtype, dnums, - slice_sizes, mode, max_idx): + slice_sizes, mode, max_idx, iteration): rng = jtu.rand_default(self.rng()) if mode == "promise_in_bounds": rng_idx = jtu.rand_int(self.rng(), high=max_idx) @@ -945,16 +905,9 @@ class LaxAutodiffTest(jtu.JaxTestCase): x = rng(shape, dtype) check_grads(gather, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - f"_shape={jtu.format_shape_dtype_string(arg_shape, dtype)}" - f"_idxs={jtu.format_shape_dtype_string(idxs.shape, idxs.dtype)}" - f"_update={update_shape}_dnums={dnums}_mode={mode}" - f"_iteration={iteration}", - "arg_shape": arg_shape, "dtype": dtype, "idxs_shape": idxs.shape, - "idxs_dtype": idxs.dtype, "update_shape": update_shape, "dnums": dnums, - "max_idx": max_idx, "mode": mode} - for dtype in grad_float_dtypes + @jtu.sample_product( + [dict(arg_shape=arg_shape, idxs_shape=idxs.shape, idxs_dtype=idxs.dtype, + dnums=dnums, update_shape=update_shape, max_idx=max_idx) for arg_shape, idxs, update_shape, dnums, max_idx in [ ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(update_window_dims=(), @@ -969,10 +922,13 @@ class LaxAutodiffTest(jtu.JaxTestCase): inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)), 3), ] - for mode in ["clip", "fill", "promise_in_bounds"] - for iteration in range(5))) + ], + dtype=grad_float_dtypes, + mode=["clip", "fill", "promise_in_bounds"], + iteration=range(5), + ) def testScatterAddGrad(self, arg_shape, dtype, idxs_shape, idxs_dtype, - update_shape, dnums, max_idx, mode): + update_shape, dnums, max_idx, mode, iteration): rng = jtu.rand_default(self.rng()) if mode == "promise_in_bounds": rng_idx = jtu.rand_int(self.rng(), high=max_idx) @@ -987,13 +943,9 @@ class LaxAutodiffTest(jtu.JaxTestCase): y = rng(update_shape, dtype) check_grads(scatter_add, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format( - jtu.format_shape_dtype_string(arg_shape, dtype), - idxs, update_shape, dnums), - "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs, - "update_shape": update_shape, "dnums": dnums, "rng_idx_factory": rng_idx_factory} - for dtype in grad_float_dtypes + @jtu.sample_product( + [dict(arg_shape=arg_shape, idxs=idxs, dnums=dnums, + update_shape=update_shape, max_idx=max_idx) for arg_shape, idxs, update_shape, dnums, max_idx in [ ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers( update_window_dims=(), inserted_window_dims=(0,), @@ -1005,13 +957,15 @@ class LaxAutodiffTest(jtu.JaxTestCase): update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)), 3), ] - # Scatters with conflicting indices are not deterministic on GPU, so we - # use indices that do not collide. - for rng_idx_factory in [partial(jtu.rand_unique_int, high=max_idx)])) + ], + dtype=grad_float_dtypes, + ) def testScatterGrad(self, arg_shape, dtype, idxs, update_shape, dnums, - rng_idx_factory): + max_idx): + # Scatters with conflicting indices are not deterministic on GPU, so we + # use indices that do not collide. + rng_idx = jtu.rand_unique_int(self.rng(), high=max_idx) rng = jtu.rand_default(self.rng()) - rng_idx = rng_idx_factory(self.rng()) idxs = rng_idx(idxs.shape, idxs.dtype) scatter = lambda x, y: lax.scatter(x, idxs, y, dimension_numbers=dnums) x = rng(arg_shape, dtype) @@ -1028,13 +982,9 @@ class LaxAutodiffTest(jtu.JaxTestCase): check_grads(f, (rng((5, 5), np.float32),), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format( - jtu.format_shape_dtype_string(arg_shape, dtype), - idxs, update_shape, dnums), - "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs, - "update_shape": update_shape, "dnums": dnums} - for dtype in grad_float_dtypes + @jtu.sample_product( + [dict(arg_shape=arg_shape, idxs=idxs, dnums=dnums, + update_shape=update_shape) for arg_shape, idxs, update_shape, dnums in [ ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers( update_window_dims=(), inserted_window_dims=(0,), @@ -1045,7 +995,10 @@ class LaxAutodiffTest(jtu.JaxTestCase): ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), - ])) + ] + ], + dtype=grad_float_dtypes, + ) def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums): rng = jtu.rand_default(self.rng()) rng_idx = jtu.rand_int(self.rng(), high=max(arg_shape)) @@ -1055,13 +1008,9 @@ class LaxAutodiffTest(jtu.JaxTestCase): y = rng(update_shape, dtype) check_grads(scatter_max, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format( - jtu.format_shape_dtype_string(arg_shape, dtype), - idxs, update_shape, dnums), - "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs, - "update_shape": update_shape, "dnums": dnums} - for dtype in grad_float_dtypes + @jtu.sample_product( + [dict(arg_shape=arg_shape, idxs=idxs, dnums=dnums, + update_shape=update_shape) for arg_shape, idxs, update_shape, dnums in [ ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers( update_window_dims=(), inserted_window_dims=(0,), @@ -1072,7 +1021,10 @@ class LaxAutodiffTest(jtu.JaxTestCase): ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), - ])) + ] + ], + dtype=grad_float_dtypes, + ) def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums): rng = jtu.rand_default(self.rng()) rng_idx = jtu.rand_int(self.rng(), high=max(arg_shape)) diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index a5167b822..ca4e4cd6c 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -137,7 +137,7 @@ if numpy_version >= (1, 22): # where keyword added in numpy 1.22 op_record("nanmean", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [], inexact=True), op_record("nanvar", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [], - inexact=True), + inexact=True, tolerance={np.float16: 3e-3}), op_record("nanstd", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [], inexact=True), ] @@ -148,7 +148,7 @@ JAX_REDUCER_NO_DTYPE_RECORDS = [ op_record("max", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []), op_record("min", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []), op_record("var", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [], - inexact=True), + inexact=True, tolerance={jnp.bfloat16: 2e-2}), op_record("std", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [], inexact=True), op_record("nanmax", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan, []), @@ -179,23 +179,25 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): for a in out] return f - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": "{}_inshape={}_axis={}_dtype={}_keepdims={}".format( - rec.test_name.capitalize(), - jtu.format_shape_dtype_string(shape, dtype), axis, - "None" if out_dtype is None else np.dtype(out_dtype).name, keepdims), - "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, "out_dtype": out_dtype, - "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), - "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} - for shape in rec.shapes for dtype in rec.dtypes - for out_dtype in [None] + rec.dtypes if out_dtype not in unsigned_dtypes + @parameterized.parameters(itertools.chain.from_iterable( + jtu.sample_product_testcases( + [dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact)], + [dict(shape=shape, axis=axis, dtype=dtype) + for shape in rec.shapes + for dtype in rec.dtypes for axis in list(range(-len(shape), len(shape))) + [None] - for keepdims in [False, True] - if jtu.is_valid_shape(shape, dtype)) - for rec in JAX_REDUCER_RECORDS)) - def testReducer(self, np_op, jnp_op, rng_factory, shape, dtype, out_dtype, + if jtu.is_valid_shape(shape, dtype) + ], + out_dtype=[out_dtype for out_dtype in [None] + rec.dtypes + if out_dtype not in unsigned_dtypes], + keepdims=[False, True], + ) + for rec in JAX_REDUCER_RECORDS + )) + def testReducer(self, name, rng_factory, shape, dtype, out_dtype, axis, keepdims, inexact): + np_op = getattr(np, name) + jnp_op = getattr(jnp, name) rng = rng_factory(self.rng()) @jtu.ignore_warning(category=np.ComplexWarning) @jtu.ignore_warning(category=RuntimeWarning, @@ -223,23 +225,26 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol) - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format( - rec.test_name.capitalize(), - jtu.format_shape_dtype_string(shape, dtype), axis, keepdims), - "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, - "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), - "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} + @parameterized.parameters(itertools.chain.from_iterable( + jtu.sample_product_testcases( + [dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact, + tolerance=rec.tolerance)], + [dict(shape=shape, axis=axis, dtype=dtype) for shape in rec.shapes for dtype in rec.dtypes for axis in list(range(-len(shape), len(shape))) + [None] - for keepdims in [False, True] - if jtu.is_valid_shape(shape, dtype)) - for rec in JAX_REDUCER_NO_DTYPE_RECORDS)) - def testReducerNoDtype(self, np_op, jnp_op, rng_factory, shape, dtype, axis, - keepdims, inexact): + if jtu.is_valid_shape(shape, dtype) + ], + keepdims=[False, True], + ) + for rec in JAX_REDUCER_NO_DTYPE_RECORDS + )) + def testReducerNoDtype(self, name, rng_factory, shape, dtype, axis, + keepdims, inexact, tolerance): + np_op = getattr(np, name) + jnp_op = getattr(jnp, name) rng = rng_factory(self.rng()) - is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' + is_bf16_nan_test = (dtype == jnp.bfloat16 and + rng_factory.__name__ == 'rand_some_nan') @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.*") @jtu.ignore_warning(category=RuntimeWarning, @@ -255,25 +260,28 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims) args_maker = lambda: [rng(shape, dtype)] - tol = {np.float16: 0.002} + tol = jtu.join_tolerance({np.float16: 0.002}, + tolerance or jtu.default_tolerance()) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) self._CompileAndCheck(jnp_fun, args_maker, rtol=tol, atol=tol) - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}".format( - rec.test_name.capitalize(), - jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial), - "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, - "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), - "initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} + @parameterized.parameters(itertools.chain.from_iterable( + jtu.sample_product_testcases( + [dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact)], + [dict(shape=shape, axis=axis, dtype=dtype) for shape in rec.shapes for dtype in rec.dtypes for axis in list(range(-len(shape), len(shape))) + [None] - for initial in [0, 1] for keepdims in [False, True] - if jtu.is_valid_shape(shape, dtype)) - for rec in JAX_REDUCER_INITIAL_RECORDS)) - def testReducerInitial(self, np_op, jnp_op, rng_factory, shape, dtype, axis, + if jtu.is_valid_shape(shape, dtype) + ], + initial=[0, 1], + keepdims=[False, True], + ) + for rec in JAX_REDUCER_INITIAL_RECORDS + )) + def testReducerInitial(self, name, rng_factory, shape, dtype, axis, keepdims, initial, inexact): + np_op = getattr(np, name) + jnp_op = getattr(jnp, name) rng = rng_factory(self.rng()) is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' @jtu.ignore_warning(category=RuntimeWarning, @@ -295,25 +303,27 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol) self._CompileAndCheck(jnp_fun, args_maker) - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}_promote_integers={}".format( - rec.test_name.capitalize(), - jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial, promote_integers), - "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, - "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), - "initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact, - "promote_integers": promote_integers} + @parameterized.parameters(itertools.chain.from_iterable( + jtu.sample_product_testcases( + [dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact)], + [dict(shape=shape, axis=axis, dtype=dtype) for shape in rec.shapes for dtype in rec.dtypes for axis in list(range(-len(shape), len(shape))) + [None] - for initial in [0, 1] for keepdims in [False, True] - for promote_integers in [True, False] - if jtu.is_valid_shape(shape, dtype)) - for rec in JAX_REDUCER_PROMOTE_INT_RECORDS)) - def testReducerPromoteInt(self, np_op, jnp_op, rng_factory, shape, dtype, axis, + if jtu.is_valid_shape(shape, dtype) + ], + initial=[0, 1], + keepdims=[False, True], + promote_integers=[False, True], + ) + for rec in JAX_REDUCER_PROMOTE_INT_RECORDS + )) + def testReducerPromoteInt(self, name, rng_factory, shape, dtype, axis, keepdims, initial, inexact, promote_integers): + np_op = getattr(np, name) + jnp_op = getattr(jnp, name) rng = rng_factory(self.rng()) - is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' + is_bf16_nan_test = (dtype == jnp.bfloat16 and + rng_factory.__name__ == 'rand_some_nan') @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.*") @jtu.ignore_warning(category=np.ComplexWarning) @@ -338,21 +348,22 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol) self._CompileAndCheck(jnp_fun, args_maker) - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format( - rec.test_name.capitalize(), - jtu.format_shape_dtype_string(shape, dtype), axis, keepdims), - "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, - "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), - "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} + @parameterized.parameters(itertools.chain.from_iterable( + jtu.sample_product_testcases( + [dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact)], + [dict(shape=shape, axis=axis) for shape in rec.shapes if np.prod(shape) == 0 - for dtype in rec.dtypes - for keepdims in [False, True] - for axis in range(-len(shape), len(shape)) if shape[axis] >= 1) - for rec in JAX_REDUCER_INITIAL_RECORDS)) - def testReducerNoInitialZeroDims(self, np_op, jnp_op, rng_factory, shape, dtype, axis, + for axis in range(-len(shape), len(shape)) if shape[axis] >= 1 + ], + dtype=rec.dtypes, + keepdims=[False, True], + ) + for rec in JAX_REDUCER_INITIAL_RECORDS + )) + def testReducerNoInitialZeroDims(self, name, rng_factory, shape, dtype, axis, keepdims, inexact): + np_op = getattr(np, name) + jnp_op = getattr(jnp, name) rng = rng_factory(self.rng()) is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' @jtu.ignore_warning(category=RuntimeWarning, @@ -374,23 +385,24 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol) self._CompileAndCheck(jnp_fun, args_maker) - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}_whereshape={}".format( - rec.test_name.capitalize(), - jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial, - jtu.format_shape_dtype_string(whereshape, bool)), - "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, - "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), "whereshape": whereshape, - "initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} + @parameterized.parameters(itertools.chain.from_iterable( + jtu.sample_product_testcases( + [dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact)], + [dict(shape=shape, axis=axis, dtype=dtype, whereshape=whereshape) for shape in rec.shapes for dtype in rec.dtypes - for whereshape in _compatible_shapes(shape) for axis in list(range(-len(shape), len(shape))) + [None] - for initial in [0, 1] for keepdims in [False, True] - if jtu.is_valid_shape(shape, dtype)) - for rec in JAX_REDUCER_INITIAL_RECORDS)) - def testReducerWhere(self, np_op, jnp_op, rng_factory, shape, dtype, axis, + if jtu.is_valid_shape(shape, dtype) + for whereshape in _compatible_shapes(shape) + ], + initial=[0, 1], + keepdims=[False, True], + ) + for rec in JAX_REDUCER_INITIAL_RECORDS + )) + def testReducerWhere(self, name, rng_factory, shape, dtype, axis, keepdims, initial, inexact, whereshape): + np_op = getattr(np, name) + jnp_op = getattr(jnp, name) if (shape in [()] + scalar_shapes and dtype in [jnp.int16, jnp.uint16] and jnp_op in [jnp.min, jnp.max]): @@ -417,23 +429,23 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_whereshape={}".format( - rec.test_name.capitalize(), - jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, - jtu.format_shape_dtype_string(whereshape, bool)), - "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, - "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), "whereshape": whereshape, - "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} - for shape in rec.shapes for dtype in rec.dtypes - for whereshape in _compatible_shapes(shape) - for axis in list(range(-len(shape), len(shape))) + [None] - for keepdims in [False, True] - if jtu.is_valid_shape(shape, dtype)) - for rec in JAX_REDUCER_WHERE_NO_INITIAL_RECORDS)) - def testReducerWhereNoInitial(self, np_op, jnp_op, rng_factory, shape, dtype, axis, - keepdims, inexact, whereshape): + @parameterized.parameters(itertools.chain.from_iterable( + jtu.sample_product_testcases( + [dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact, + tol=rec.tolerance)], + [dict(shape=shape, axis=axis, dtype=dtype, whereshape=whereshape) + for shape in rec.shapes for dtype in rec.dtypes + for whereshape in _compatible_shapes(shape) + for axis in list(range(-len(shape), len(shape))) + [None] + if jtu.is_valid_shape(shape, dtype) + ], + keepdims=[False, True], + ) for rec in JAX_REDUCER_WHERE_NO_INITIAL_RECORDS + )) + def testReducerWhereNoInitial(self, name, rng_factory, shape, dtype, axis, + keepdims, inexact, whereshape, tol): + np_op = getattr(np, name) + jnp_op = getattr(jnp, name) rng = rng_factory(self.rng()) is_bf16_nan_test = dtype == jnp.bfloat16 # Do not pass where via args_maker as that is incompatible with _promote_like_jnp. @@ -458,7 +470,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] if numpy_version >= (1, 20, 2) or np_op.__name__ in ("all", "any"): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, atol=tol, rtol=tol) self._CompileAndCheck(jnp_fun, args_maker) def testReductionOfOutOfBoundsAxis(self): # Issue 888 @@ -510,19 +522,14 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=check_dtypes, rtol=tol, atol=tol) - @parameterized.named_parameters( - jtu.cases_from_list( - {"testcase_name": - "_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}" - .format(shape, dtype.__name__, out_dtype.__name__, axis, ddof, keepdims), - "shape": shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis, - "ddof": ddof, "keepdims": keepdims} - for shape in [(5,), (10, 5)] - for dtype in all_dtypes - for out_dtype in inexact_dtypes - for axis in [None, 0, -1] - for ddof in [0, 1, 2] - for keepdims in [False, True])) + @jtu.sample_product( + shape=[(5,), (10, 5)], + dtype=all_dtypes, + out_dtype=inexact_dtypes, + axis=[None, 0, -1], + ddof=[0, 1, 2], + keepdims=[False, True], + ) def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims): rng = jtu.rand_default(self.rng()) args_maker = self._GetArgsMaker(rng, [shape], [dtype]) @@ -547,19 +554,14 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_fun, args_maker, rtol=tol, atol=tol) - @parameterized.named_parameters( - jtu.cases_from_list( - {"testcase_name": - "_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}" - .format(shape, dtype, out_dtype, axis, ddof, keepdims), - "shape": shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis, - "ddof": ddof, "keepdims": keepdims} - for shape in [(5,), (10, 5)] - for dtype in all_dtypes - for out_dtype in inexact_dtypes - for axis in [None, 0, -1] - for ddof in [0, 1, 2] - for keepdims in [False, True])) + @jtu.sample_product( + shape=[(5,), (10, 5)], + dtype=all_dtypes, + out_dtype=inexact_dtypes, + axis=[None, 0, -1], + ddof=[0, 1, 2], + keepdims=[False, True], + ) def testNanVar(self, shape, dtype, out_dtype, axis, ddof, keepdims): rng = jtu.rand_some_nan(self.rng()) args_maker = self._GetArgsMaker(rng, [shape], [dtype]) @@ -594,22 +596,20 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): z = jax.grad(jnp.nanstd)(x) self.assertEqual(jnp.isnan(z).sum(), 0) - @parameterized.named_parameters( - jtu.cases_from_list( - {"testcase_name": - "_shape={}_dtype={}_y_shape={}_y_dtype={}_rowvar={}_ddof={}_bias={}_fweights={}_aweights={}".format( - shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, aweights), - "shape": shape, "y_shape": y_shape, "dtype": dtype, "y_dtype": y_dtype,"rowvar": rowvar, "ddof": ddof, - "bias": bias, "fweights": fweights, "aweights": aweights} - for shape in [(5,), (10, 5), (5, 10)] - for dtype in all_dtypes - for y_dtype in [None, dtype] - for rowvar in [True, False] - for y_shape in _get_y_shapes(y_dtype, shape, rowvar) - for bias in [True, False] - for ddof in [None, 2, 3] - for fweights in [True, False] - for aweights in [True, False])) + @jtu.sample_product( + [dict(shape=shape, dtype=dtype, y_dtype=y_dtype, rowvar=rowvar, + y_shape=y_shape) + for shape in [(5,), (10, 5), (5, 10)] + for dtype in all_dtypes + for y_dtype in [None, dtype] + for rowvar in [True, False] + for y_shape in _get_y_shapes(y_dtype, shape, rowvar) + ], + bias=[True, False], + ddof=[None, 2, 3], + fweights=[True, False], + aweights=[True, False], + ) @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion def testCov(self, shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, aweights): rng = jtu.rand_default(self.rng()) diff --git a/tests/lax_scipy_sparse_test.py b/tests/lax_scipy_sparse_test.py index d2a36c975..0b35d0822 100644 --- a/tests/lax_scipy_sparse_test.py +++ b/tests/lax_scipy_sparse_test.py @@ -15,7 +15,6 @@ from functools import partial import unittest -from absl.testing import parameterized from absl.testing import absltest import numpy as np import scipy.sparse.linalg @@ -93,15 +92,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase): M = None return M - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_shape={}_preconditioner={}".format( - jtu.format_shape_dtype_string(shape, dtype), - preconditioner), - "shape": shape, "dtype": dtype, "preconditioner": preconditioner} - for shape in [(4, 4), (7, 7)] - for dtype in [np.float64, np.complex128] - for preconditioner in [None, 'identity', 'exact', 'random'])) + @jtu.sample_product( + shape=[(4, 4), (7, 7)], + dtype=[np.float64, np.complex128], + preconditioner=[None, 'identity', 'exact', 'random'], + ) def test_cg_against_scipy(self, shape, dtype, preconditioner): if not config.x64_enabled: raise unittest.SkipTest("requires x64 mode") @@ -132,12 +127,10 @@ class LaxBackedScipyTests(jtu.JaxTestCase): args_maker, tol=1e-6) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - f"_shape={jtu.format_shape_dtype_string(shape, dtype)}", - "shape": shape, "dtype": dtype} - for shape in [(2, 2)] - for dtype in float_types + complex_types)) + @jtu.sample_product( + shape=[(2, 2)], + dtype=float_types + complex_types, + ) def test_cg_as_solve(self, shape, dtype): rng = jtu.rand_default(self.rng()) @@ -219,16 +212,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase): self.assertTrue(dtypes.is_weakly_typed(x)) # BICGSTAB - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_shape={}_preconditioner={}".format( - jtu.format_shape_dtype_string(shape, dtype), - preconditioner), - "shape": shape, "dtype": dtype, "preconditioner": preconditioner} - for shape in [(5, 5)] - for dtype in [np.float64, np.complex128] - for preconditioner in [None, 'identity', 'exact', 'random'] - )) + @jtu.sample_product( + shape=[(5, 5)], + dtype=[np.float64, np.complex128], + preconditioner=[None, 'identity', 'exact', 'random'], + ) def test_bicgstab_against_scipy( self, shape, dtype, preconditioner): if not config.jax_enable_x64: @@ -266,16 +254,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase): args_maker, tol=1e-4) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_shape={}_preconditioner={}".format( - jtu.format_shape_dtype_string(shape, dtype), - preconditioner), - "shape": shape, "dtype": dtype, "preconditioner": preconditioner} - for shape in [(2, 2), (7, 7)] - for dtype in float_types + complex_types - for preconditioner in [None, 'identity', 'exact'] - )) + @jtu.sample_product( + shape=[(2, 2), (7, 7)], + dtype=float_types + complex_types, + preconditioner=[None, 'identity', 'exact'], + ) @jtu.skip_on_devices("gpu") def test_bicgstab_on_identity_system(self, shape, dtype, preconditioner): A = jnp.eye(shape[1], dtype=dtype) @@ -290,17 +273,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase): solution_tol = 1e-8 if using_x64 else 1e-4 self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_shape={}_preconditioner={}".format( - jtu.format_shape_dtype_string(shape, dtype), - preconditioner), - "shape": shape, "dtype": dtype, "preconditioner": preconditioner - } - for shape in [(2, 2), (4, 4)] - for dtype in float_types + complex_types - for preconditioner in [None, 'identity', 'exact'] - )) + @jtu.sample_product( + shape=[(2, 2), (4, 4)], + dtype=float_types + complex_types, + preconditioner=[None, 'identity', 'exact'], + ) @jtu.skip_on_devices("gpu") def test_bicgstab_on_random_system(self, shape, dtype, preconditioner): rng = jtu.rand_default(self.rng()) @@ -339,18 +316,12 @@ class LaxBackedScipyTests(jtu.JaxTestCase): self.assertAllClose(expected, actual) # GMRES - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_shape={}_preconditioner={}_solve_method={}".format( - jtu.format_shape_dtype_string(shape, dtype), - preconditioner, - solve_method), - "shape": shape, "dtype": dtype, "preconditioner": preconditioner, - "solve_method": solve_method} - for shape in [(3, 3)] - for dtype in [np.float64, np.complex128] - for preconditioner in [None, 'identity', 'exact', 'random'] - for solve_method in ['incremental', 'batched'])) + @jtu.sample_product( + shape=[(3, 3)], + dtype=[np.float64, np.complex128], + preconditioner=[None, 'identity', 'exact', 'random'], + solve_method=['incremental', 'batched'], + ) def test_gmres_against_scipy( self, shape, dtype, preconditioner, solve_method): if not config.x64_enabled: @@ -380,27 +351,20 @@ class LaxBackedScipyTests(jtu.JaxTestCase): partial(scipy_gmres, M=M, restart=2, maxiter=1), partial(lax_gmres, M=M, restart=2, maxiter=1, solve_method=solve_method), args_maker, - tol=1e-10) + tol=1e-9) self._CheckAgainstNumpy( np.linalg.solve, partial(lax_gmres, M=M, atol=1e-6, solve_method=solve_method), args_maker, - tol=1e-10) + tol=1e-8) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_shape={}_preconditioner={}_solve_method={}".format( - jtu.format_shape_dtype_string(shape, dtype), - preconditioner, - solve_method), - "shape": shape, "dtype": dtype, "preconditioner": preconditioner, - "solve_method": solve_method} - for shape in [(2, 2), (7, 7)] - for dtype in float_types + complex_types - for preconditioner in [None, 'identity', 'exact'] - for solve_method in ['batched', 'incremental'] - )) + @jtu.sample_product( + shape=[(2, 2), (7, 7)], + dtype=float_types + complex_types, + preconditioner=[None, 'identity', 'exact'], + solve_method=['batched', 'incremental'], + ) @jtu.skip_on_devices("gpu") def test_gmres_on_identity_system(self, shape, dtype, preconditioner, solve_method): @@ -419,19 +383,12 @@ class LaxBackedScipyTests(jtu.JaxTestCase): solution_tol = 1e-8 if using_x64 else 1e-4 self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_shape={}_preconditioner={}_solve_method={}".format( - jtu.format_shape_dtype_string(shape, dtype), - preconditioner, - solve_method), - "shape": shape, "dtype": dtype, "preconditioner": preconditioner, - "solve_method": solve_method} - for shape in [(2, 2), (4, 4)] - for dtype in float_types + complex_types - for preconditioner in [None, 'identity', 'exact'] - for solve_method in ['incremental', 'batched'] - )) + @jtu.sample_product( + shape=[(2, 2), (4, 4)], + dtype=float_types + complex_types, + preconditioner=[None, 'identity', 'exact'], + solve_method=['incremental', 'batched'], + ) @jtu.skip_on_devices("gpu") def test_gmres_on_random_system(self, shape, dtype, preconditioner, solve_method): @@ -469,15 +426,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase): actual, _ = jax.scipy.sparse.linalg.gmres(A, b) self.assertAllClose(expected, actual) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_shape={}_preconditioner={}".format( - jtu.format_shape_dtype_string(shape, dtype), - preconditioner), - "shape": shape, "dtype": dtype, "preconditioner": preconditioner} - for shape in [(2, 2), (3, 3)] - for dtype in float_types + complex_types - for preconditioner in [None, 'identity'])) + @jtu.sample_product( + shape=[(2, 2), (3, 3)], + dtype=float_types + complex_types, + preconditioner=[None, 'identity'], + ) def test_gmres_arnoldi_step(self, shape, dtype, preconditioner): """ The Arnoldi decomposition within GMRES is correct. diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index 9ecc79348..7ff36560a 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -141,7 +141,8 @@ JAX_SPECIAL_FUNCTION_RECORDS = [ # of inputs to some reasonable intervals op_record("zeta", 2, float_dtypes, jtu.rand_positive, False), # TODO: float64 produces aborts on gpu, potentially related to use of jnp.piecewise - op_record("expi", 1, [np.float32], jtu.rand_default, True), + op_record("expi", 1, [np.float32], partial(jtu.rand_not_small, offset=0.1), + True), op_record("exp1", 1, [np.float32], jtu.rand_positive, True), op_record("expn", 2, (int_dtypes, [np.float32]), jtu.rand_positive, True, (0,)), ] @@ -154,23 +155,19 @@ class LaxBackedScipyTests(jtu.JaxTestCase): return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)] - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_shapes={}_axis={}_keepdims={}_return_sign={}_use_b_{}".format( - jtu.format_shape_dtype_string(shapes, dtype), - axis, keepdims, return_sign, use_b), - # TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU. - "shapes": shapes, "dtype": dtype, - "axis": axis, "keepdims": keepdims, - "return_sign": return_sign, "use_b": use_b} - for shape_group in compatible_shapes for dtype in float_dtypes + complex_dtypes + int_dtypes + @jtu.sample_product( + [dict(shapes=shapes, axis=axis, use_b=use_b) + for shape_group in compatible_shapes for use_b in [False, True] for shapes in itertools.product(*( (shape_group, shape_group) if use_b else (shape_group,))) for axis in range(-max(len(shape) for shape in shapes), max(len(shape) for shape in shapes)) - for keepdims in [False, True] - for return_sign in [False, True])) + ], + dtype=float_dtypes + complex_dtypes + int_dtypes, + keepdims=[False, True], + return_sign=[False, True], + ) @jtu.ignore_warning(category=RuntimeWarning, message="invalid value encountered in .*") @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. def testLogSumExp(self, shapes, dtype, axis, @@ -235,23 +232,23 @@ class LaxBackedScipyTests(jtu.JaxTestCase): result = lsp_special.logsumexp(1.0, b=1.0) self.assertEqual(result, 1.0) - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix( - rec.test_name, shapes, dtypes), - "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, - "test_autodiff": rec.test_autodiff, - "nondiff_argnums": rec.nondiff_argnums, - "scipy_op": getattr(osp_special, rec.name), - "lax_op": getattr(lsp_special, rec.name)} - for shapes in itertools.combinations_with_replacement(all_shapes, rec.nargs) - for dtypes in (itertools.combinations_with_replacement(rec.dtypes, rec.nargs) - if isinstance(rec.dtypes, list) else itertools.product(*rec.dtypes))) - for rec in JAX_SPECIAL_FUNCTION_RECORDS)) + @parameterized.parameters(itertools.chain.from_iterable( + jtu.sample_product_testcases( + [dict(op=rec.name, rng_factory=rec.rng_factory, + test_autodiff=rec.test_autodiff, + nondiff_argnums=rec.nondiff_argnums)], + shapes=itertools.combinations_with_replacement(all_shapes, rec.nargs), + dtypes=(itertools.combinations_with_replacement(rec.dtypes, rec.nargs) + if isinstance(rec.dtypes, list) else itertools.product(*rec.dtypes)), + ) + for rec in JAX_SPECIAL_FUNCTION_RECORDS + )) @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. @jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion - def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes, dtypes, + def testScipySpecialFun(self, op, rng_factory, shapes, dtypes, test_autodiff, nondiff_argnums): + scipy_op = getattr(osp_special, op) + lax_op = getattr(lsp_special, op) rng = rng_factory(self.rng()) args_maker = self._GetArgsMaker(rng, shapes, dtypes) args = args_maker() @@ -272,13 +269,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase): atol=jtu.if_device_under_test("tpu", .1, 1e-3), rtol=.1, eps=1e-3) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_inshape={}_d={}".format( - jtu.format_shape_dtype_string(shape, dtype), d), - "shape": shape, "dtype": dtype, "d": d} - for shape in all_shapes - for dtype in float_dtypes - for d in [1, 2, 5])) + @jtu.sample_product( + shape=all_shapes, + dtype=float_dtypes, + d=[1, 2, 5], + ) @jax.numpy_rank_promotion('raise') def testMultigammaln(self, shape, dtype, d): def scipy_fun(a): @@ -322,13 +317,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase): partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.) self.assertAllClose(jax.grad(partial_xlog1py)(-1.), 0., check_dtypes=False) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_lmax={}".format( - jtu.format_shape_dtype_string(shape, dtype), l_max), - "l_max": l_max, "shape": shape, "dtype": dtype} - for l_max in [1, 2, 3, 6] - for shape in [(5,), (10,)] - for dtype in float_dtypes)) + @jtu.sample_product( + l_max=[1, 2, 3, 6], + shape=[(5,), (10,)], + dtype=float_dtypes, + ) def testLpmn(self, l_max, shape, dtype): rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9) args_maker = lambda: [rng(shape, dtype)] @@ -344,13 +337,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase): atol=3e-3, check_dtypes=False) self._CompileAndCheck(lax_fun, args_maker, rtol=1E-5, atol=3e-3) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_lmax={}".format( - jtu.format_shape_dtype_string(shape, dtype), l_max), - "l_max": l_max, "shape": shape, "dtype": dtype} - for l_max in [3, 4, 6, 32] - for shape in [(2,), (3,), (4,), (64,)] - for dtype in float_dtypes)) + @jtu.sample_product( + l_max=[3, 4, 6, 32], + shape=[(2,), (3,), (4,), (64,)], + dtype=float_dtypes, + ) def testNormalizedLpmnValues(self, l_max, shape, dtype): rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9) args_maker = lambda: [rng(shape, dtype)] @@ -432,11 +423,12 @@ class LaxBackedScipyTests(jtu.JaxTestCase): self.assertAllClose(actual, expected, rtol=1e-8, atol=6e-8) - @parameterized.named_parameters(jtu.cases_from_list( - {'testcase_name': f'_maxdegree={l_max}_inputsize={num_z}_dtype={dtype.__name__}', - 'l_max': l_max, 'num_z': num_z, 'dtype': dtype} + @jtu.sample_product( + [dict(l_max=l_max, num_z=num_z) for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8]) - for dtype in jtu.dtypes.all_integer)) + ], + dtype=jtu.dtypes.all_integer, + ) @jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype): """Tests against JIT compatibility and Numpy.""" @@ -475,32 +467,18 @@ class LaxBackedScipyTests(jtu.JaxTestCase): self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5) - @parameterized.named_parameters(jtu.cases_from_list( - {'testcase_name': - '_shape={}' - '_n_zero_sv={}_degeneracy={}_geometric_spectrum={}' - '_max_sv={}_method={}_side={}' - '_nonzero_condition_number={}_seed={}'.format( - jtu.format_shape_dtype_string( - shape, jnp.dtype(dtype).name).replace(" ", ""), - n_zero_sv, degeneracy, geometric_spectrum, max_sv, - method, side, nonzero_condition_number, seed - ), - 'n_zero_sv': n_zero_sv, 'degeneracy': degeneracy, - 'geometric_spectrum': geometric_spectrum, - 'max_sv': max_sv, 'shape': shape, 'method': method, - 'side': side, 'nonzero_condition_number': nonzero_condition_number, - 'dtype': dtype, 'seed': seed} - for n_zero_sv in n_zero_svs - for degeneracy in degeneracies - for geometric_spectrum in geometric_spectra - for max_sv in max_svs - for shape in polar_shapes - for method in methods - for side in sides - for nonzero_condition_number in nonzero_condition_numbers - for dtype in jtu.dtypes.inexact - for seed in seeds)) + @jtu.sample_product( + n_zero_sv=n_zero_svs, + degeneracy=degeneracies, + geometric_spectrum=geometric_spectra, + max_sv=max_svs, + shape=polar_shapes, + method=methods, + side=sides, + nonzero_condition_number=nonzero_condition_numbers, + dtype=jtu.dtypes.inexact, + seed=seeds, + ) def testPolar( self, n_zero_sv, degeneracy, geometric_spectrum, max_sv, shape, method, side, nonzero_condition_number, dtype, seed): @@ -552,15 +530,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase): self.assertAllClose( matrix, recon, atol=tol * jnp.linalg.norm(matrix)) - @parameterized.named_parameters(jtu.cases_from_list( - {'testcase_name': - '_linear_size={}_dtype={}_termination_size={}'.format( - linear_size, jnp.dtype(dtype).name, termination_size), - 'linear_size': linear_size, 'dtype': dtype, - 'termination_size': termination_size} - for linear_size in linear_sizes - for dtype in jtu.dtypes.floating + jtu.dtypes.complex - for termination_size in [1, 19])) + @jtu.sample_product( + linear_size=linear_sizes, + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + termination_size=[1, 19], + ) def test_spectral_dac_eigh(self, linear_size, dtype, termination_size): if jtu.device_under_test() != "tpu" and termination_size != 1: raise unittest.SkipTest( @@ -584,13 +558,12 @@ class LaxBackedScipyTests(jtu.JaxTestCase): HV, vV, atol=atol * (80 if jnp.issubdtype(dtype, jnp.complexfloating) else 30)) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_{jtu.format_shape_dtype_string((n_obs, n_codes, *n_feats), dtype)}", - "n_obs": n_obs, "n_codes": n_codes, "n_feats": n_feats, "dtype": dtype} - for n_obs in [1, 3, 5] - for n_codes in [1, 2, 4] - for n_feats in [()] + [(i,) for i in range(1, 3)] - for dtype in float_dtypes + int_dtypes)) # scipy doesn't support complex + @jtu.sample_product( + n_obs=[1, 3, 5], + n_codes=[1, 2, 4], + n_feats=[()] + [(i,) for i in range(1, 3)], + dtype=float_dtypes + int_dtypes, # scipy doesn't support complex + ) def test_vq(self, n_obs, n_codes, n_feats, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng((n_obs, *n_feats), dtype), rng((n_codes, *n_feats), dtype)] diff --git a/tests/multibackend_test.py b/tests/multibackend_test.py index e374e61f7..2d6328c4c 100644 --- a/tests/multibackend_test.py +++ b/tests/multibackend_test.py @@ -16,7 +16,6 @@ from functools import partial from absl.testing import absltest -from absl.testing import parameterized import numpy as np import numpy.random as npr @@ -35,12 +34,7 @@ npr.seed(0) class MultiBackendTest(jtu.JaxTestCase): """Tests jit targeting to different backends.""" - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_backend={backend}", - "backend": backend, - } - for backend in ['cpu', 'gpu', 'tpu', None] - )) + @jtu.sample_product(backend=['cpu', 'gpu', 'tpu', None]) def testMultiBackend(self, backend): if backend not in ('cpu', jtu.device_under_test(), None): raise SkipTest("Backend is not CPU or the device under test") @@ -56,10 +50,9 @@ class MultiBackendTest(jtu.JaxTestCase): correct_platform = backend if backend else jtu.device_under_test() self.assertEqual(z.device().platform, correct_platform) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_ordering={ordering}", - "ordering": ordering,} - for ordering in [('cpu', None), ('gpu', None), ('tpu', None), (None, None)])) + @jtu.sample_product( + ordering=[('cpu', None), ('gpu', None), ('tpu', None), (None, None)] + ) def testMultiBackendNestedJit(self, ordering): outer, inner = ordering if outer not in ('cpu', jtu.device_under_test(), None): @@ -81,14 +74,11 @@ class MultiBackendTest(jtu.JaxTestCase): correct_platform = outer if outer else jtu.device_under_test() self.assertEqual(z.device().platform, correct_platform) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_ordering={ordering}", - "ordering": ordering,} - for ordering in [ - ('cpu', 'gpu'), ('gpu', 'cpu'), - ('cpu', 'tpu'), ('tpu', 'cpu'), - (None, 'cpu'), (None, 'gpu'), (None, 'tpu'), - ])) + @jtu.sample_product( + ordering=[('cpu', 'gpu'), ('gpu', 'cpu'), ('cpu', 'tpu'), ('tpu', 'cpu'), + (None, 'cpu'), (None, 'gpu'), (None, 'tpu'), + ], + ) def testMultiBackendNestedJitConflict(self, ordering): outer, inner = ordering if outer not in ('cpu', jtu.device_under_test(), None): @@ -111,11 +101,7 @@ class MultiBackendTest(jtu.JaxTestCase): y = npr.uniform(size=(10, 10)) self.assertRaises(ValueError, lambda: fun(x, y)) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_backend={backend}", - "backend": backend,} - for backend in ['cpu', 'gpu', 'tpu'] - )) + @jtu.sample_product(backend=['cpu', 'gpu', 'tpu']) def testGpuMultiBackendOpByOpReturn(self, backend): if backend not in ('cpu', jtu.device_under_test()): raise SkipTest("Backend is not CPU or the device under test") diff --git a/tests/random_test.py b/tests/random_test.py index c2fbebb37..14b446312 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -262,9 +262,7 @@ class PrngTest(jtu.JaxTestCase): expected64 = np.array([676898860, 3164047411, 4010691890], dtype=np.uint32) self.assertArraysEqual(bits64, expected64) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_" + name, "prng_name": name} - for name, _ in PRNG_IMPLS)) + @jtu.sample_product(prng_name=[name for name, _ in PRNG_IMPLS]) def testRngRandomBitsShapeDtype(self, prng_name): # Like testRngRandomBits, but only meant to exercise random_bits # on every PRNG implementation. Instead of values, only checks @@ -313,9 +311,7 @@ class PrngTest(jtu.JaxTestCase): assert np.all(rand_bits_32 == rand_bits_32[0]) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": case._testname(), "case": case} - for case in _RANDOM_VALUES_CASES)) + @jtu.sample_product(case=_RANDOM_VALUES_CASES) @jtu.skip_on_devices("tpu") # TPU precision causes issues. def testRandomDistributionValues(self, case): """ @@ -366,8 +362,7 @@ class PrngTest(jtu.JaxTestCase): _prng_key_as_array(random.fold_in(k, 4)), np.array([2285895361, 433833334], dtype='uint32')) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "seed={seed}_type={type}_jit={jit}".format(**dct), **dct} for dct in [ + @jtu.sample_product([ {"seed": 0, "type": int, "jit": True, "key": [0, 0]}, {"seed": 0, "type": int, "jit": False, "key": [0, 0]}, {"seed": 1, "type": np.int32, "jit": True, "key": [0, 1]}, @@ -390,8 +385,7 @@ class PrngTest(jtu.JaxTestCase): {"seed": np.iinfo(np.int32).min - 100, "type": int, "jit": False, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]}, {"seed": np.iinfo(np.int32).min - 101, "type": np.int64, "jit": True, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]}, {"seed": np.iinfo(np.int32).min - 101, "type": np.int64, "jit": False, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]}, - ] - )) + ]) def test_prng_seeds_and_keys(self, seed, type, jit, key): if (jit and type is int and not config.x64_enabled and (seed < np.iinfo('int32').min or seed > np.iinfo('int32').max)): @@ -525,9 +519,7 @@ class LaxRandomTest(jtu.JaxTestCase): def seed_prng(self, seed): return random.threefry2x32_key(seed) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype} - for dtype in jtu.dtypes.floating)) + @jtu.sample_product(dtype=jtu.dtypes.floating) def testNumpyAndXLAAgreeOnFloatEndianness(self, dtype): bits_dtype = np.uint32 if jnp.finfo(dtype).bits == 32 else np.uint64 numpy_bits = np.array(1., dtype).view(bits_dtype) @@ -535,9 +527,7 @@ class LaxRandomTest(jtu.JaxTestCase): lambda: lax.bitcast_convert_type(np.array(1., dtype), bits_dtype))() self.assertEqual(numpy_bits, xla_bits) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype} - for dtype in float_dtypes)) + @jtu.sample_product(dtype=float_dtypes) def testRngUniform(self, dtype): key = self.seed_prng(0) rand = lambda key: random.uniform(key, (10000,), dtype) @@ -550,9 +540,7 @@ class LaxRandomTest(jtu.JaxTestCase): self._CheckCollisions(samples, jnp.finfo(dtype).nmant) self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype} - for dtype in int_dtypes + uint_dtypes)) + @jtu.sample_product(dtype=int_dtypes + uint_dtypes) def testRngRandint(self, dtype): lo = 5 hi = 10 @@ -568,9 +556,7 @@ class LaxRandomTest(jtu.JaxTestCase): self.assertTrue(np.all(lo <= samples)) self.assertTrue(np.all(samples < hi)) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype} - for dtype in float_dtypes)) + @jtu.sample_product(dtype=float_dtypes) def testNormal(self, dtype): key = self.seed_prng(0) rand = lambda key: random.normal(key, (10000,), dtype) @@ -589,9 +575,7 @@ class LaxRandomTest(jtu.JaxTestCase): res_bfloat16 = random.normal(self.seed_prng(0), dtype=jnp.bfloat16) self.assertAllClose(res_bfloat16, res_bfloat16_str) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"dtype={np.dtype(dtype).name}", "dtype": dtype} - for dtype in complex_dtypes)) + @jtu.sample_product(dtype=complex_dtypes) def testNormalComplex(self, dtype): key = self.seed_prng(0) rand = lambda key: random.normal(key, (10000,), dtype) @@ -605,9 +589,7 @@ class LaxRandomTest(jtu.JaxTestCase): self._CheckKolmogorovSmirnovCDF(jnp.imag(samples), scipy.stats.norm(scale=1/np.sqrt(2)).cdf) self.assertEqual(dtype, samples.dtype) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype} - for dtype in float_dtypes)) + @jtu.sample_product(dtype=float_dtypes) def testTruncatedNormal(self, dtype): key = self.seed_prng(0) rand = lambda key: random.truncated_normal(key, -0.3, 0.3, (10000,), dtype) @@ -623,9 +605,7 @@ class LaxRandomTest(jtu.JaxTestCase): for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.truncnorm(-0.3, 0.3).cdf) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype} - for dtype in jtu.dtypes.floating + jtu.dtypes.integer)) + @jtu.sample_product(dtype=jtu.dtypes.floating + jtu.dtypes.integer) def testShuffle(self, dtype): key = self.seed_prng(0) x = np.arange(100).astype(dtype) @@ -641,23 +621,21 @@ class LaxRandomTest(jtu.JaxTestCase): self.assertFalse(np.all(perm1 == x)) # seems unlikely! self.assertAllClose(np.sort(perm1), x, check_dtypes=False) - @parameterized.named_parameters(jtu.cases_from_list( - dict( - testcase_name= - f"_{np.dtype(dtype).name}_input_range_or_shape={input_range_or_shape}" - f"_shape={shape}_replace={replace}_weighted={weighted}_axis={axis}", - dtype=dtype, input_range_or_shape=input_range_or_shape, - shape=shape, replace=replace, weighted=weighted, axis=axis) - for dtype in jtu.dtypes.floating + jtu.dtypes.integer - for shape in [(), (5,), (4, 5)] - for replace in [True, False] - for weighted in [True, False] - for input_range_or_shape in [100, (10, 10), (10, 5, 2), 1, (1, 5)] - for is_range in [type(input_range_or_shape) is int] - for ndim in [1 if is_range else len(input_range_or_shape)] - for axis in range(-ndim, ndim or 1) - for ninputs in [input_range_or_shape if is_range else input_range_or_shape[axis]] - if replace or np.prod(shape) <= ninputs)) + @jtu.sample_product( + [dict(shape=shape, replace=replace, axis=axis, + input_range_or_shape=input_range_or_shape) + for shape in [(), (5,), (4, 5)] + for replace in [True, False] + for input_range_or_shape in [100, (10, 10), (10, 5, 2), 1, (1, 5)] + for is_range in [type(input_range_or_shape) is int] + for ndim in [1 if is_range else len(input_range_or_shape)] + for axis in range(-ndim, ndim or 1) + for ninputs in [input_range_or_shape if is_range else input_range_or_shape[axis]] + if replace or np.prod(shape) <= ninputs + ], + dtype=jtu.dtypes.floating + jtu.dtypes.integer, + weighted=[True, False], + ) def testChoice(self, dtype, input_range_or_shape, shape, replace, weighted, axis): # This is the function API that we test against (note that self.rng().choice differs) np_choice = np.random.default_rng(0).choice @@ -690,17 +668,16 @@ class LaxRandomTest(jtu.JaxTestCase): self.assertArraysEqual(sample, jax.jit(rand, static_argnames= 'x' if is_range else None)(key, x)) - @parameterized.named_parameters(jtu.cases_from_list( - dict( - testcase_name=f"_dtype={dtype}_range_or_shape={range_or_shape}" - f"_axis={axis}_independent={independent}", - dtype=dtype, range_or_shape=range_or_shape, axis=axis, independent=independent) - for dtype in jtu.dtypes.floating + jtu.dtypes.integer - for range_or_shape in [0, 1, 100, (0,), (1,), (100,), - (10, 10), (10, 5, 2), (0, 5), (1, 5)] - for ndim in [1 if type(range_or_shape) is int else len(range_or_shape)] - for axis in range(-ndim, ndim or 1) - for independent in [True, False])) + @jtu.sample_product( + [dict(range_or_shape=range_or_shape, axis=axis) + for range_or_shape in [0, 1, 100, (0,), (1,), (100,), + (10, 10), (10, 5, 2), (0, 5), (1, 5)] + for ndim in [1 if type(range_or_shape) is int else len(range_or_shape)] + for axis in range(-ndim, ndim or 1) + ], + dtype=jtu.dtypes.floating + jtu.dtypes.integer, + independent=[True, False], + ) def testPermutation(self, dtype, range_or_shape, axis, independent): key = self.seed_prng(0) is_range = type(range_or_shape) is int @@ -739,11 +716,10 @@ class LaxRandomTest(jtu.JaxTestCase): with self.assertRaises(core.ConcretizationTypeError): jax.jit(random.permutation)(key, 10) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_p={p}_dtype={np.dtype(dtype).name}", - "p": p, "dtype": dtype} - for p in [0.1, 0.5, 0.9] - for dtype in jtu.dtypes.floating)) + @jtu.sample_product( + p=[0.1, 0.5, 0.9], + dtype=jtu.dtypes.floating, + ) def testBernoulli(self, p, dtype): key = self.seed_prng(0) p = np.array(p, dtype=dtype) @@ -756,17 +732,18 @@ class LaxRandomTest(jtu.JaxTestCase): for samples in [uncompiled_samples, compiled_samples]: self._CheckChiSquared(samples, scipy.stats.bernoulli(p).pmf) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_p={p}_{np.dtype(dtype).name}_{sample_shape}", - "p": p, "axis": axis, "dtype": dtype, 'sample_shape': sample_shape} - for (p, axis) in [ + @jtu.sample_product( + [dict(p=p, axis=axis) + for (p, axis) in [ ([.25] * 4, -1), ([.1, .2, .3, .4], -1), ([[.5, .5], [.1, .9]], 1), ([[.5, .1], [.5, .9]], 0), - ] - for sample_shape in [(10000,), (5000, 2)] - for dtype in jtu.dtypes.floating)) + ] + ], + sample_shape=[(10000,), (5000, 2)], + dtype=jtu.dtypes.floating, + ) def testCategorical(self, p, axis, dtype, sample_shape): key = self.seed_prng(0) p = np.array(p, dtype=dtype) @@ -800,12 +777,11 @@ class LaxRandomTest(jtu.JaxTestCase): x = random.bernoulli(key, np.array([0.2, 0.3]), shape=(3, 2)) assert x.shape == (3, 2) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_a={a}_b={b}_dtype={np.dtype(dtype).name}", - "a": a, "b": b, "dtype": dtype} - for a in [0.2, 5.] - for b in [0.2, 5.] - for dtype in [np.float64])) # NOTE: KS test fails with float32 + @jtu.sample_product( + a=[0.2, 5.], + b=[0.2, 5.], + dtype=[np.float64], # NOTE: KS test fails with float32 + ) def testBeta(self, a, b, dtype): if not config.x64_enabled: raise SkipTest("skip test except on X64") @@ -832,9 +808,7 @@ class LaxRandomTest(jtu.JaxTestCase): ones = samples[samples >= 0.5] self.assertAllClose(ones, jnp.ones_like(ones)) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype} - for dtype in float_dtypes)) + @jtu.sample_product(dtype=float_dtypes) def testCauchy(self, dtype): key = self.seed_prng(0) rand = lambda key: random.cauchy(key, (10000,), dtype) @@ -846,13 +820,10 @@ class LaxRandomTest(jtu.JaxTestCase): for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.cauchy().cdf) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_alpha={alpha}_dtype={np.dtype(dtype).name}", - "alpha": alpha, "dtype": dtype} - for alpha in [ - np.array([0.2, 1., 5.]), - ] - for dtype in jtu.dtypes.floating)) + @jtu.sample_product( + alpha=[np.array([0.2, 1., 5.]),], + dtype=jtu.dtypes.floating, + ) @jtu.skip_on_devices("tpu") # TODO(mattjj): slow compilation times def testDirichlet(self, alpha, dtype): key = self.seed_prng(0) @@ -883,9 +854,7 @@ class LaxRandomTest(jtu.JaxTestCase): self.assertAllClose(samples.max(1), jnp.ones(samples.shape[0]), check_dtypes=False, rtol=1E-5) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype} - for dtype in float_dtypes)) + @jtu.sample_product(dtype=float_dtypes) def testExponential(self, dtype): key = self.seed_prng(0) rand = lambda key: random.exponential(key, (10000,), dtype) @@ -897,11 +866,10 @@ class LaxRandomTest(jtu.JaxTestCase): for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.expon().cdf) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_a={a}_dtype={np.dtype(dtype).name}", - "a": a, "dtype": dtype} - for a in [0.1, 1., 10.] - for dtype in jtu.dtypes.floating)) + @jtu.sample_product( + a=[0.1, 1., 10.], + dtype=jtu.dtypes.floating, + ) def testGammaVsLogGamma(self, a, dtype): key = self.seed_prng(0) rand_gamma = lambda key, a: random.gamma(key, a, (10000,), dtype) @@ -911,11 +879,10 @@ class LaxRandomTest(jtu.JaxTestCase): self.assertAllClose(rand_gamma(key, a), jnp.exp(rand_loggamma(key, a))) self.assertAllClose(rand_gamma(key, a), jnp.exp(crand_loggamma(key, a))) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_a={a}_dtype={np.dtype(dtype).name}", - "a": a, "dtype": dtype} - for a in [0.1, 1., 10.] - for dtype in jtu.dtypes.floating)) + @jtu.sample_product( + a=[0.1, 1., 10.], + dtype=jtu.dtypes.floating, + ) def testGamma(self, a, dtype): key = self.seed_prng(0) rand = lambda key, a: random.gamma(key, a, (10000,), dtype) @@ -932,11 +899,10 @@ class LaxRandomTest(jtu.JaxTestCase): x = random.gamma(key, np.array([0.2, 0.3]), shape=(3, 2)) assert x.shape == (3, 2) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_a={alpha}_logspace={log_space}", - "alpha": alpha, "log_space": log_space} - for log_space in [True, False] - for alpha in [1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4])) + @jtu.sample_product( + log_space=[True, False], + alpha=[1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4], + ) def testGammaGrad(self, log_space, alpha): rng = self.seed_prng(0) alphas = np.full((100,), alpha) @@ -969,11 +935,10 @@ class LaxRandomTest(jtu.JaxTestCase): # Should not crash with a type error. jax.vjp(f, a, b) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_lam={lam}_dtype={np.dtype(dtype).name}", - "lam": lam, "dtype": np.dtype(dtype)} - for lam in [0.5, 3, 9, 11, 50, 500] - for dtype in [np.int16, np.int32, np.int64])) + @jtu.sample_product( + lam=[0.5, 3, 9, 11, 50, 500], + dtype=[np.int16, np.int32, np.int64], + ) def testPoisson(self, lam, dtype): key = self.seed_prng(0) rand = lambda key, lam: random.poisson(key, lam, (10000,), dtype) @@ -1019,9 +984,7 @@ class LaxRandomTest(jtu.JaxTestCase): samples = random.poisson(key, lam, shape=(3,)) self.assertArraysEqual(samples, jnp.array([-1, 0, -1])) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype} - for dtype in jtu.dtypes.floating)) + @jtu.sample_product(dtype=jtu.dtypes.floating) def testGumbel(self, dtype): key = self.seed_prng(0) rand = lambda key: random.gumbel(key, (10000,), dtype) @@ -1033,9 +996,7 @@ class LaxRandomTest(jtu.JaxTestCase): for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gumbel_r().cdf) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype} - for dtype in float_dtypes)) + @jtu.sample_product(dtype=float_dtypes) def testLaplace(self, dtype): key = self.seed_prng(0) rand = lambda key: random.laplace(key, (10000,), dtype) @@ -1047,9 +1008,7 @@ class LaxRandomTest(jtu.JaxTestCase): for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.laplace().cdf) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype} - for dtype in float_dtypes)) + @jtu.sample_product(dtype=float_dtypes) def testLogistic(self, dtype): key = self.seed_prng(0) rand = lambda key: random.logistic(key, (10000,), dtype) @@ -1061,15 +1020,11 @@ class LaxRandomTest(jtu.JaxTestCase): for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.logistic().cdf) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_n={}_shape={}"\ - .format(n, jtu.format_shape_dtype_string(shape, dtype)), - "n": n, - "shape": shape, - "dtype": dtype} - for n in range(1, 5) - for shape in [(), (5,), (10, 5)] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex)) + @jtu.sample_product( + n=range(1, 5), + shape=[(), (5,), (10, 5)], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) def testOrthogonal(self, n, shape, dtype): key = self.seed_prng(0) q = random.orthogonal(key, n, shape, dtype) @@ -1083,15 +1038,11 @@ class LaxRandomTest(jtu.JaxTestCase): atol=tol, rtol=tol, ) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_p={}_shape={}"\ - .format(p, jtu.format_shape_dtype_string(shape, dtype)), - "p": p, - "shape": shape, - "dtype": dtype} - for p in [.5, 1., 1.5, 2., 2.5] - for shape in [(), (5,), (10, 5)] - for dtype in jtu.dtypes.floating)) + @jtu.sample_product( + p=[.5, 1., 1.5, 2., 2.5], + shape=[(), (5,), (10, 5)], + dtype=jtu.dtypes.floating, + ) def testGeneralizedNormal(self, p, shape, dtype): key = self.seed_prng(0) rand = lambda key, p: random.generalized_normal(key, p, shape, dtype) @@ -1103,17 +1054,12 @@ class LaxRandomTest(jtu.JaxTestCase): self.assertEqual(samples.dtype, dtype) self._CheckKolmogorovSmirnovCDF(samples.ravel(), scipy.stats.gennorm(p).cdf) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_d={}_p={}_shape={}"\ - .format(d, p, jtu.format_shape_dtype_string(shape, dtype)), - "d": d, - "p": p, - "shape": shape, - "dtype": dtype} - for d in range(1, 5) - for p in [.5, 1., 1.5, 2., 2.5] - for shape in [(), (5,), (10, 5)] - for dtype in jtu.dtypes.floating)) + @jtu.sample_product( + d=range(1, 5), + p=[.5, 1., 1.5, 2., 2.5], + shape=[(), (5,), (10, 5)], + dtype=jtu.dtypes.floating, + ) def testBall(self, d, p, shape, dtype): key = self.seed_prng(0) rand = lambda key, p: random.ball(key, d, p, shape, dtype) @@ -1127,11 +1073,10 @@ class LaxRandomTest(jtu.JaxTestCase): norms = (jnp.abs(samples) ** p).sum(-1) ** (d / p) self._CheckKolmogorovSmirnovCDF(norms.ravel(), scipy.stats.uniform().cdf) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_b={b}_dtype={np.dtype(dtype).name}", - "b": b, "dtype": dtype} - for b in [0.1, 1., 10.] - for dtype in jtu.dtypes.floating)) + @jtu.sample_product( + b=[0.1, 1., 10.], + dtype=jtu.dtypes.floating, + ) def testPareto(self, b, dtype): key = self.seed_prng(0) rand = lambda key, b: random.pareto(key, b, (10000,), dtype) @@ -1149,11 +1094,10 @@ class LaxRandomTest(jtu.JaxTestCase): x = random.pareto(key, np.array([0.2, 0.3]), shape=(3, 2)) assert x.shape == (3, 2) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_df={df}_dtype={np.dtype(dtype).name}", - "df": df, "dtype": dtype} - for df in [0.1, 1., 10.] - for dtype in jtu.dtypes.floating)) + @jtu.sample_product( + df=[0.1, 1., 10.], + dtype=jtu.dtypes.floating, + ) @jtu.skip_on_devices("cpu", "tpu") # TODO(phawkins): slow compilation times def testT(self, df, dtype): key = self.seed_prng(0) @@ -1166,13 +1110,11 @@ class LaxRandomTest(jtu.JaxTestCase): for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.t(df).cdf) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_dim={}_dtype={}_method={}".format( - dim, np.dtype(dtype), method), - "dim": dim, "dtype": dtype, "method": method} - for dim in [1, 3, 5] - for dtype in float_dtypes - for method in ['svd', 'eigh', 'cholesky'])) + @jtu.sample_product( + dim=[1, 3, 5], + dtype=float_dtypes, + method=['svd', 'eigh', 'cholesky'], + ) def testMultivariateNormal(self, dim, dtype, method): r = self.rng() mean = r.randn(dim) @@ -1198,18 +1140,13 @@ class LaxRandomTest(jtu.JaxTestCase): # eigenvectors follow a standard normal distribution. self._CheckKolmogorovSmirnovCDF(whitened.ravel(), scipy.stats.norm().cdf) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_dim={}_mean_batch_size={}_cov_batch_size={}_shape={}_method={}"\ - .format(dim, mean_batch_size, cov_batch_size, shape, method), - "dim": dim, - "mean_batch_size": mean_batch_size, - "cov_batch_size": cov_batch_size, - "shape": shape, "method": method} - for dim in [1, 2, 4] - for mean_batch_size in [(), (3,), (2, 3)] - for cov_batch_size in [(), (3,), (2, 3)] - for shape in [(), (1,), (5,)] - for method in ['cholesky', 'svd', 'eigh'])) + @jtu.sample_product( + dim=[1, 2, 4], + mean_batch_size=[(), (3,), (2, 3)], + cov_batch_size=[(), (3,), (2, 3)], + shape=[(), (1,), (5,)], + method=['cholesky', 'svd', 'eigh'], + ) def testMultivariateNormalShapes(self, dim, mean_batch_size, cov_batch_size, shape, method): r = self.rng() @@ -1426,10 +1363,10 @@ class LaxRandomTest(jtu.JaxTestCase): with jax.enable_checks(False): # check_jaxpr will materialize array jax.eval_shape(f, 0) # doesn't error - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_seed={seed}_type={type_}", "seed": seed, "type_": type_} - for type_ in ["int", "np.array", "jnp.array"] - for seed in [-1, 0, 1, (1 << 32) - 1, (1 << 63) - 1, np.uint64((1 << 64) - 1)])) + @jtu.sample_product( + type_=["int", "np.array", "jnp.array"], + seed=[-1, 0, 1, (1 << 32) - 1, (1 << 63) - 1, np.uint64((1 << 64) - 1)], + ) def test_prng_jit_invariance(self, seed, type_): if type_ == "int" and seed == (1 << 64) - 1: self.skipTest("Expected failure: Python int too large.") @@ -1453,9 +1390,7 @@ class LaxRandomTest(jtu.JaxTestCase): jax.jit(random.split)(key) self.assertLessEqual(count[0], 1) # 1 for the argument device_put - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_dtype={dtype}", "dtype": dtype} - for dtype in int_dtypes + uint_dtypes)) + @jtu.sample_product(dtype=int_dtypes + uint_dtypes) def test_randint_bounds(self, dtype): min = np.iinfo(dtype).min max = np.iinfo(dtype).max @@ -1541,10 +1476,7 @@ class KeyArrayTest(jtu.JaxTestCase): self.assertIsInstance(keys, random.KeyArray) self.assertEqual(keys.shape, (3,)) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_internal" if use_internal else "", - "use_internal": use_internal} - for use_internal in [False, True])) + @jtu.sample_product(use_internal=[False, True]) def test_random_unwrap(self, use_internal): unwrap = prng_internal.random_unwrap if use_internal else random.key_data def f(k): return unwrap(k) diff --git a/tests/scipy_signal_test.py b/tests/scipy_signal_test.py index af2d5a10f..75d22bd0f 100644 --- a/tests/scipy_signal_test.py +++ b/tests/scipy_signal_test.py @@ -16,7 +16,7 @@ from functools import partial import unittest -from absl.testing import absltest, parameterized +from absl.testing import absltest import numpy as np import scipy.signal as osp_signal @@ -70,22 +70,19 @@ def _complex_dtype(dtype): class LaxBackedScipySignalTests(jtu.JaxTestCase): """Tests for LAX-backed scipy.stats implementations""" - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_op={}_xshape={}_yshape={}_mode={}".format( - op, - jtu.format_shape_dtype_string(xshape, dtype), - jtu.format_shape_dtype_string(yshape, dtype), - mode), - "xshape": xshape, "yshape": yshape, "dtype": dtype, "mode": mode, - "jsp_op": getattr(jsp_signal, op), - "osp_op": getattr(osp_signal, op)} - for mode in ['full', 'same', 'valid'] - for op in ['convolve', 'correlate'] - for dtype in default_dtypes - for shapeset in [onedim_shapes, twodim_shapes, threedim_shapes] - for xshape in shapeset - for yshape in shapeset)) - def testConvolutions(self, xshape, yshape, dtype, mode, jsp_op, osp_op): + @jtu.sample_product( + [dict(xshape=xshape, yshape=yshape) + for shapeset in [onedim_shapes, twodim_shapes, threedim_shapes] + for xshape in shapeset + for yshape in shapeset + ], + mode=['full', 'same', 'valid'], + op=['convolve', 'correlate'], + dtype=default_dtypes, + ) + def testConvolutions(self, xshape, yshape, dtype, mode, op): + jsp_op = getattr(jsp_signal, op) + osp_op = getattr(osp_signal, op) rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)] osp_fun = partial(osp_op, mode=mode) @@ -94,21 +91,16 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase): self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False, tol=tol) self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "op={}_xshape={}_yshape={}_mode={}".format( - op, - jtu.format_shape_dtype_string(xshape, dtype), - jtu.format_shape_dtype_string(yshape, dtype), - mode), - "xshape": xshape, "yshape": yshape, "dtype": dtype, "mode": mode, - "jsp_op": getattr(jsp_signal, op), - "osp_op": getattr(osp_signal, op)} - for mode in ['full', 'same', 'valid'] - for op in ['convolve2d', 'correlate2d'] - for dtype in default_dtypes - for xshape in twodim_shapes - for yshape in twodim_shapes)) - def testConvolutions2D(self, xshape, yshape, dtype, mode, jsp_op, osp_op): + @jtu.sample_product( + mode=['full', 'same', 'valid'], + op=['convolve2d', 'correlate2d'], + dtype=default_dtypes, + xshape=twodim_shapes, + yshape=twodim_shapes, + ) + def testConvolutions2D(self, xshape, yshape, dtype, mode, op): + jsp_op = getattr(jsp_signal, op) + osp_op = getattr(osp_signal, op) rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)] osp_fun = partial(osp_op, mode=mode) @@ -118,15 +110,13 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase): tol=tol) self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_axis={}_type={}_bp={}".format( - jtu.format_shape_dtype_string(shape, dtype), axis, type, bp), - "shape": shape, "dtype": dtype, "axis": axis, "type": type, "bp": bp} - for shape in [(5,), (4, 5), (3, 4, 5)] - for dtype in jtu.dtypes.floating + jtu.dtypes.integer - for axis in [0, -1] - for type in ['constant', 'linear'] - for bp in [0, [0, 2]])) + @jtu.sample_product( + shape=[(5,), (4, 5), (3, 4, 5)], + dtype=jtu.dtypes.floating + jtu.dtypes.integer, + axis=[0, -1], + type=['constant', 'linear'], + bp=[0, [0, 2]], + ) def testDetrend(self, shape, dtype, axis, type, bp): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] @@ -144,24 +134,19 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase): self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=tol) self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - f"_shape={jtu.format_shape_dtype_string(shape, dtype)}" - f"_fs={fs}_window={window}_boundary={boundary}_detrend={detrend}" - f"_padded={padded}_nperseg={nperseg}_noverlap={noverlap}" - f"_axis={timeaxis}_nfft={nfft}", - "shape": shape, "dtype": dtype, "fs": fs, "window": window, - "nperseg": nperseg, "noverlap": noverlap, "nfft": nfft, - "detrend": detrend, "boundary": boundary, "padded": padded, - "timeaxis": timeaxis} + @jtu.sample_product( + [dict(shape=shape, nperseg=nperseg, noverlap=noverlap, timeaxis=timeaxis, + nfft=nfft) for shape, nperseg, noverlap, timeaxis in stft_test_shapes - for dtype in default_dtypes - for fs in [1.0, 16000.0] - for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann'] for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2] - for detrend in ['constant', 'linear', False] - for boundary in [None, 'even', 'odd', 'zeros'] - for padded in [True, False])) + ], + dtype=default_dtypes, + fs=[1.0, 16000.0], + window=['boxcar', 'triang', 'blackman', 'hamming', 'hann'], + detrend=['constant', 'linear', False], + boundary=[None, 'even', 'odd', 'zeros'], + padded=[True, False], + ) def testStftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg, noverlap, nfft, detrend, boundary, padded, timeaxis): @@ -193,26 +178,19 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase): # Tests with `average == 'median'`` is excluded from `testCsd*` # due to the issue: # https://github.com/scipy/scipy/issues/15601 - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - f"_xshape={jtu.format_shape_dtype_string(xshape, dtype)}" - f"_yshape={jtu.format_shape_dtype_string(yshape, dtype)}" - f"_average={average}_scaling={scaling}_nfft={nfft}" - f"_fs={fs}_window={window}_detrend={detrend}" - f"_nperseg={nperseg}_noverlap={noverlap}" - f"_axis={timeaxis}", - "xshape": xshape, "yshape": yshape, "dtype": dtype, "fs": fs, - "window": window, "nperseg": nperseg, "noverlap": noverlap, - "nfft": nfft, "detrend": detrend, "scaling": scaling, - "timeaxis": timeaxis, "average": average} + @jtu.sample_product( + [dict(xshape=xshape, yshape=yshape, nperseg=nperseg, noverlap=noverlap, + timeaxis=timeaxis, nfft=nfft) for xshape, yshape, nperseg, noverlap, timeaxis in csd_test_shapes - for dtype in default_dtypes - for fs in [1.0, 16000.0] - for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann'] for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2] - for detrend in ['constant', 'linear', False] - for scaling in ['density', 'spectrum'] - for average in ['mean'])) + ], + dtype=default_dtypes, + fs=[1.0, 16000.0], + window=['boxcar', 'triang', 'blackman', 'hamming', 'hann'], + detrend=['constant', 'linear', False], + scaling=['density', 'spectrum'], + average=['mean'], + ) def testCsdAgainstNumpy( self, *, xshape, yshape, dtype, fs, window, nperseg, noverlap, nfft, detrend, scaling, timeaxis, average): @@ -242,25 +220,19 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase): self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol) self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - f"_shape={jtu.format_shape_dtype_string(shape, dtype)}" - f"_average={average}_scaling={scaling}_nfft={nfft}" - f"_fs={fs}_window={window}_detrend={detrend}" - f"_nperseg={nperseg}_noverlap={noverlap}" - f"_axis={timeaxis}", - "shape": shape, "dtype": dtype, "fs": fs, - "window": window, "nperseg": nperseg, "noverlap": noverlap, - "nfft": nfft, "detrend": detrend, "scaling": scaling, - "timeaxis": timeaxis, "average": average} - for shape, unused_yshape, nperseg, noverlap, timeaxis in csd_test_shapes - for dtype in default_dtypes - for fs in [1.0, 16000.0] - for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann'] + @jtu.sample_product( + [dict(shape=shape, nperseg=nperseg, noverlap=noverlap, timeaxis=timeaxis, + nfft=nfft) + for shape, _yshape, nperseg, noverlap, timeaxis in csd_test_shapes for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2] - for detrend in ['constant', 'linear', False] - for scaling in ['density', 'spectrum'] - for average in ['mean'])) + ], + dtype=default_dtypes, + fs=[1.0, 16000.0], + window=['boxcar', 'triang', 'blackman', 'hamming', 'hann'], + detrend=['constant', 'linear', False], + scaling=['density', 'spectrum'], + average=['mean'], + ) def testCsdWithSameParamAgainstNumpy( self, *, shape, dtype, fs, window, nperseg, noverlap, nfft, detrend, scaling, timeaxis, average): @@ -292,26 +264,20 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase): self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol) self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - f"_shape={jtu.format_shape_dtype_string(shape, dtype)}" - f"_fs={fs}_window={window}" - f"_nperseg={nperseg}_noverlap={noverlap}_nfft={nfft}" - f"_detrend={detrend}_return_onesided={return_onesided}" - f"_scaling={scaling}_axis={timeaxis}_average={average}", - "shape": shape, "dtype": dtype, "fs": fs, "window": window, - "nperseg": nperseg, "noverlap": noverlap, "nfft": nfft, - "detrend": detrend, "return_onesided": return_onesided, - "scaling": scaling, "timeaxis": timeaxis, "average": average} + @jtu.sample_product( + [dict(shape=shape, nperseg=nperseg, noverlap=noverlap, timeaxis=timeaxis, + nfft=nfft) for shape, nperseg, noverlap, timeaxis in welch_test_shapes - for dtype in default_dtypes - for fs in [1.0, 16000.0] - for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann'] for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2] - for detrend in ['constant', 'linear', False] - for return_onesided in [True, False] - for scaling in ['density', 'spectrum'] - for average in ['mean', 'median'])) + ], + dtype=default_dtypes, + fs=[1.0, 16000.0], + window=['boxcar', 'triang', 'blackman', 'hamming', 'hann'], + detrend=['constant', 'linear', False], + return_onesided=[True, False], + scaling=['density', 'spectrum'], + average=['mean', 'median'], + ) def testWelchAgainstNumpy(self, *, shape, dtype, fs, window, nperseg, noverlap, nfft, detrend, return_onesided, scaling, timeaxis, average): @@ -342,20 +308,14 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase): self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol) self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - f"_shape={jtu.format_shape_dtype_string(shape, dtype)}" - f"_nperseg={nperseg}_noverlap={noverlap}" - f"_use_nperseg={use_nperseg}_use_overlap={use_noverlap}" - f"_axis={timeaxis}", - "shape": shape, "dtype": dtype, - "nperseg": nperseg, "noverlap": noverlap, - "use_nperseg": use_nperseg, "use_noverlap": use_noverlap, - "timeaxis": timeaxis} + @jtu.sample_product( + [dict(shape=shape, nperseg=nperseg, noverlap=noverlap, timeaxis=timeaxis) for shape, nperseg, noverlap, timeaxis in welch_test_shapes - for use_nperseg in [False, True] - for use_noverlap in [False, True] - for dtype in jtu.dtypes.floating + jtu.dtypes.integer)) + ], + use_nperseg=[False, True], + use_noverlap=[False, True], + dtype=jtu.dtypes.floating + jtu.dtypes.integer, + ) def testWelchWithDefaultStepArgsAgainstNumpy( self, *, shape, dtype, nperseg, noverlap, use_nperseg, use_noverlap, timeaxis): @@ -386,23 +346,18 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase): self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol) self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - f"_shape={jtu.format_shape_dtype_string(shape, dtype)}" - f"_fs={fs}_window={window}_boundary={boundary}" - f"_nperseg={nperseg}_noverlap={noverlap}_onesided={onesided}" - f"_timeaxis={timeaxis}_freqaxis{freqaxis}_nfft={nfft}", - "shape": shape, "dtype": dtype, "fs": fs, "window": window, - "nperseg": nperseg, "noverlap": noverlap, "nfft": nfft, - "onesided": onesided, "boundary": boundary, - "timeaxis": timeaxis, "freqaxis": freqaxis} + @jtu.sample_product( + [dict(shape=shape, nperseg=nperseg, noverlap=noverlap, timeaxis=timeaxis, + freqaxis=freqaxis, nfft=nfft) for shape, nperseg, noverlap, timeaxis, freqaxis in istft_test_shapes - for dtype in default_dtypes - for fs in [1.0, 16000.0] - for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann'] for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2] - for onesided in [False, True] - for boundary in [False, True])) + ], + dtype=default_dtypes, + fs=[1.0, 16000.0], + window=['boxcar', 'triang', 'blackman', 'hamming', 'hann'], + onesided=[False, True], + boundary=[False, True], + ) def testIstftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg, noverlap, nfft, onesided, boundary, timeaxis, freqaxis): diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index 4f22f9595..04bec1324 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -16,7 +16,7 @@ from functools import partial import itertools -from absl.testing import absltest, parameterized +from absl.testing import absltest import numpy as np import scipy.stats as osp_stats @@ -34,12 +34,10 @@ one_and_two_dim_shapes = [(4,), (3, 4), (3, 1), (1, 4)] def genNamedParametersNArgs(n): - return parameterized.named_parameters( - jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix("", shapes, dtypes), - "shapes": shapes, "dtypes": dtypes} - for shapes in itertools.combinations_with_replacement(all_shapes, n) - for dtypes in itertools.combinations_with_replacement(jtu.dtypes.floating, n))) + return jtu.sample_product( + shapes=itertools.combinations_with_replacement(all_shapes, n), + dtypes=itertools.combinations_with_replacement(jtu.dtypes.floating, n), + ) # Allow implicit rank promotion in these tests, as virtually every test exercises it. @@ -179,14 +177,14 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) - @parameterized.named_parameters( - jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix("", [x_shape, alpha_shape], dtypes), - "shapes": [x_shape, alpha_shape], "dtypes": dtypes} + @jtu.sample_product( + shapes=[ + [x_shape, alpha_shape] for x_shape in one_and_two_dim_shapes for alpha_shape in [(x_shape[0],), (x_shape[0] + 1,)] - for dtypes in itertools.combinations_with_replacement(jtu.dtypes.floating, 2) - )) + ], + dtypes=itertools.combinations_with_replacement(jtu.dtypes.floating, 2), + ) def testDirichletLogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) @@ -564,13 +562,12 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): lsp_stats.norm.cdf(np.full((4,), np.inf, np.float32)), check_dtypes=False) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix("", [shape, shape], [x_dtype, p_dtype]), - "x_dtype": x_dtype, - "p_dtype": p_dtype, - "shape": shape} - for shape in [(2), (4,), (1, 5)] - for (x_dtype, p_dtype) in itertools.product(jtu.dtypes.integer, jtu.dtypes.floating))) + @jtu.sample_product( + [dict(x_dtype=x_dtype, p_dtype=p_dtype) + for x_dtype, p_dtype in itertools.product(jtu.dtypes.integer, jtu.dtypes.floating) + ], + shape=[(2), (4,), (1, 5)], + ) def testMultinomialLogPmf(self, shape, x_dtype, p_dtype): rng = jtu.rand_positive(self.rng()) scipy_fun = osp_stats.multinomial.logpmf @@ -588,16 +585,8 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): tol=5e-4) self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_x={}_mean={}_cov={}".format( - jtu.format_shape_dtype_string(x_shape, x_dtype), - jtu.format_shape_dtype_string(mean_shape, mean_dtype) - if mean_shape is not None else None, - jtu.format_shape_dtype_string(cov_shape, cov_dtype) - if cov_shape is not None else None), - "x_shape": x_shape, "x_dtype": x_dtype, - "mean_shape": mean_shape, "mean_dtype": mean_dtype, - "cov_shape": cov_shape, "cov_dtype": cov_dtype} + @jtu.sample_product( + [dict(x_shape=x_shape, mean_shape=mean_shape, cov_shape=cov_shape) for x_shape, mean_shape, cov_shape in [ # # These test cases cover default values for mean/cov, but we don't # # support those yet (and they seem not very valuable). @@ -616,9 +605,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): [(3, 4), (4,), (4, 4)], [(2, 3, 4), (4,), (4, 4)], ] - for x_dtype, mean_dtype, cov_dtype in itertools.combinations_with_replacement(jtu.dtypes.floating, 3) - if (mean_shape is not None or mean_dtype == np.float32) - and (cov_shape is not None or cov_dtype == np.float32))) + ], + [dict(x_dtype=x_dtype, mean_dtype=mean_dtype, cov_dtype=cov_dtype) + for x_dtype, mean_dtype, cov_dtype in itertools.combinations_with_replacement(jtu.dtypes.floating, 3) + ], + # if (mean_shape is not None or mean_dtype == np.float32) + # and (cov_shape is not None or cov_dtype == np.float32))) + ) def testMultivariateNormalLogpdf(self, x_shape, x_dtype, mean_shape, mean_dtype, cov_shape, cov_dtype): rng = jtu.rand_default(self.rng()) @@ -642,16 +635,8 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): rtol=1e-4, atol=1e-4) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_x={}_mean={}_cov={}".format( - jtu.format_shape_dtype_string(x_shape, x_dtype), - jtu.format_shape_dtype_string(mean_shape, mean_dtype) - if mean_shape is not None else None, - jtu.format_shape_dtype_string(cov_shape, cov_dtype) - if cov_shape is not None else None), - "x_shape": x_shape, "x_dtype": x_dtype, - "mean_shape": mean_shape, "mean_dtype": mean_dtype, - "cov_shape": cov_shape, "cov_dtype": cov_dtype} + @jtu.sample_product( + [dict(x_shape=x_shape, mean_shape=mean_shape, cov_shape=cov_shape) for x_shape, mean_shape, cov_shape in [ # These test cases are where scipy flattens things, which has # different batch semantics than some might expect, so we manually @@ -663,9 +648,11 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): [(1, 3, 2), (3, 2,), (5, 1, 2, 2)], [(5, 3, 2), (1, 2,), (2, 2)], ] - for x_dtype, mean_dtype, cov_dtype in itertools.combinations_with_replacement(jtu.dtypes.floating, 3) - if (mean_shape is not None or mean_dtype == np.float32) - and (cov_shape is not None or cov_dtype == np.float32))) + ], + [dict(x_dtype=x_dtype, mean_dtype=mean_dtype, cov_dtype=cov_dtype) + for x_dtype, mean_dtype, cov_dtype in itertools.combinations_with_replacement(jtu.dtypes.floating, 3) + ], + ) def testMultivariateNormalLogpdfBroadcasted(self, x_shape, x_dtype, mean_shape, mean_dtype, cov_shape, cov_dtype): rng = jtu.rand_default(self.rng()) @@ -691,12 +678,11 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): rtol=1e-4, atol=1e-4) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_ndim={ndim}_nbatch={nbatch}_dtype={dtype.__name__}", - "ndim": ndim, "nbatch": nbatch, "dtype": dtype} - for ndim in [2, 3] - for nbatch in [1, 3, 5] - for dtype in jtu.dtypes.floating)) + @jtu.sample_product( + ndim=[2, 3], + nbatch=[1, 3, 5], + dtype=jtu.dtypes.floating, + ) def testMultivariateNormalLogpdfBatch(self, ndim, nbatch, dtype): # Regression test for #5570 rng = jtu.rand_default(self.rng()) @@ -709,23 +695,14 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): result2 = jax.vmap(lsp_stats.multivariate_normal.logpdf)(x, mean, cov) self.assertArraysEqual(result1, result2, check_dtypes=False) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_inshape={}_outsize={}_weights={}_method={}_func={}".format( - jtu.format_shape_dtype_string(inshape, dtype), - outsize, weights, method, func), - "dtype": dtype, - "inshape": inshape, - "outsize": outsize, - "weights": weights, - "method": method, - "func": func} - for inshape in [(50,), (3, 50), (2, 12)] - for dtype in jtu.dtypes.floating - for outsize in [None, 10] - for weights in [False, True] - for method in [None, "scott", "silverman", 1.5, "callable"] - for func in [None, "evaluate", "logpdf", "pdf"])) + @jtu.sample_product( + inshape=[(50,), (3, 50), (2, 12)], + dtype=jtu.dtypes.floating, + outsize=[None, 10], + weights=[False, True], + method=[None, "scott", "silverman", 1.5, "callable"], + func=[None, "evaluate", "logpdf", "pdf"], + ) def testKde(self, inshape, dtype, outsize, weights, method, func): if method == "callable": method = lambda kde: jax.numpy.power(kde.neff, -1./(kde.d+4)) @@ -764,12 +741,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CompileAndCheck( lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15}) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]), - "dtype": dtype, - "shape": shape} - for shape in [(15,), (3, 15), (1, 12)] - for dtype in jtu.dtypes.floating)) + @jtu.sample_product( + shape=[(15,), (3, 15), (1, 12)], + dtype=jtu.dtypes.floating, + ) def testKdeIntegrateGaussian(self, shape, dtype): def scipy_fun(dataset, weights): kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights)) @@ -796,12 +771,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CompileAndCheck( lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15}) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]), - "dtype": dtype, - "shape": shape} - for shape in [(15,), (12,)] - for dtype in jtu.dtypes.floating)) + @jtu.sample_product( + shape=[(15,), (12,)], + dtype=jtu.dtypes.floating, + ) def testKdeIntegrateBox1d(self, shape, dtype): def scipy_fun(dataset, weights): kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights)) @@ -820,12 +793,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CompileAndCheck( lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15}) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]), - "dtype": dtype, - "shape": shape} - for shape in [(15,), (3, 15), (1, 12)] - for dtype in jtu.dtypes.floating)) + @jtu.sample_product( + shape=[(15,), (3, 15), (1, 12)], + dtype=jtu.dtypes.floating, + ) def testKdeIntegrateKde(self, shape, dtype): def scipy_fun(dataset, weights): kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights)) @@ -848,12 +819,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CompileAndCheck( lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15}) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]), - "dtype": dtype, - "shape": shape} - for shape in [(15,), (3, 15), (1, 12)] - for dtype in jtu.dtypes.floating)) + @jtu.sample_product( + shape=[(15,), (3, 15), (1, 12)], + dtype=jtu.dtypes.floating, + ) def testKdeResampleShape(self, shape, dtype): def resample(key, dataset, weights, *, shape): kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights)) @@ -878,12 +847,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): result = func(*args) assert result.shape == (ndim, 4) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]), - "dtype": dtype, - "shape": shape} - for shape in [(15,), (1, 12)] - for dtype in jtu.dtypes.floating)) + @jtu.sample_product( + shape=[(15,), (1, 12)], + dtype=jtu.dtypes.floating, + ) def testKdeResample1d(self, shape, dtype): rng = jtu.rand_default(self.rng()) dataset = rng(shape, dtype) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index a67f24d3e..34b2c76d2 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -150,11 +150,10 @@ class cuSparseTest(jtu.JaxTestCase): yield self.assertEmpty(caught_warnings) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_{jtu.format_shape_dtype_string(shape, dtype)}", - "shape": shape, "dtype": dtype} - for shape in [(5, 8), (8, 5), (5, 5), (8, 8)] - for dtype in all_dtypes)) + @jtu.sample_product( + shape=[(5, 8), (8, 5), (5, 5), (8, 8)], + dtype=all_dtypes, + ) def test_csr_todense(self, shape, dtype): rng = rand_sparse(self.rng(), post=scipy.sparse.csr_matrix) M = rng(shape, dtype) @@ -166,11 +165,10 @@ class cuSparseTest(jtu.JaxTestCase): with self.gpu_dense_conversion_warning_context(dtype): self.assertArraysEqual(M.toarray(), jit(todense)(*args)) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_{jtu.format_shape_dtype_string(shape, dtype)}", - "shape": shape, "dtype": dtype} - for shape in [(5, 8), (8, 5), (5, 5), (8, 8)] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex)) + @jtu.sample_product( + shape=[(5, 8), (8, 5), (5, 5), (8, 8)], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) def test_csr_todense_ad(self, shape, dtype): rng = rand_sparse(self.rng(), post=jnp.array) M = rng(shape, dtype) @@ -189,11 +187,10 @@ class cuSparseTest(jtu.JaxTestCase): self.assertArraysEqual(primals, f(data)) self.assertArraysEqual(data_out, data) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_{jtu.format_shape_dtype_string(shape, dtype)}", - "shape": shape, "dtype": dtype} - for shape in [(5, 8), (8, 5), (5, 5), (8, 8)] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex)) + @jtu.sample_product( + shape=[(5, 8), (8, 5), (5, 5), (8, 8)], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) def test_csr_fromdense_ad(self, shape, dtype): rng = rand_sparse(self.rng(), post=jnp.array) M = rng(shape, dtype) @@ -218,17 +215,17 @@ class cuSparseTest(jtu.JaxTestCase): self.assertArraysEqual(M_out, M) @unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision") - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_{}".format( - jtu.format_shape_dtype_string(shape, dtype), - jtu.format_shape_dtype_string(bshape, dtype)), - "shape": shape, "dtype": dtype, "bshape": bshape} + @jtu.sample_product( + [dict(shape=shape, bshape=bshape) for shape in [(5, 8), (8, 5), (5, 5), (8, 8)] for bshape in [shape[-1:] + s for s in [(), (1,), (3,)]] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex)) + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) def test_csr_matmul_ad(self, shape, dtype, bshape): csr_matmul = sparse.csr_matvec if len(bshape) == 1 else sparse.csr_matmat - tol = {np.float32: 1E-5, np.float64: 1E-12, np.complex64: 1E-5, np.complex128: 1E-12} + tol = {np.float32: 2E-5, np.float64: 1E-12, np.complex64: 1E-5, + np.complex128: 1E-12} rng = rand_sparse(self.rng(), post=jnp.array) rng_b = jtu.rand_default(self.rng()) @@ -273,11 +270,10 @@ class cuSparseTest(jtu.JaxTestCase): self.assertAllClose(primals_dense[0], primals_sparse[0], atol=tol, rtol=tol) self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_{jtu.format_shape_dtype_string(shape, dtype)}", - "shape": shape, "dtype": dtype} - for shape in [(5, 8), (8, 5), (5, 5), (8, 8)] - for dtype in all_dtypes)) + @jtu.sample_product( + shape=[(5, 8), (8, 5), (5, 5), (8, 8)], + dtype=all_dtypes, + ) def test_csr_fromdense(self, shape, dtype): rng = rand_sparse(self.rng()) M = rng(shape, dtype) @@ -298,12 +294,11 @@ class cuSparseTest(jtu.JaxTestCase): self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype)) self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype)) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_{jtu.format_shape_dtype_string(shape, dtype)}_T={transpose}", - "shape": shape, "dtype": dtype, "transpose": transpose} - for shape in [(5, 8), (8, 5), (5, 5), (8, 8)] - for dtype in all_dtypes - for transpose in [True, False])) + @jtu.sample_product( + shape=[(5, 8), (8, 5), (5, 5), (8, 8)], + dtype=all_dtypes, + transpose=[True, False], + ) @jtu.skip_on_devices("rocm") # will be fixed in rocm-5.1 def test_csr_matvec(self, shape, dtype, transpose): op = lambda M: M.T if transpose else M @@ -320,12 +315,11 @@ class cuSparseTest(jtu.JaxTestCase): with self.gpu_matmul_warning_context(dtype): self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_{jtu.format_shape_dtype_string(shape, dtype)}_T={transpose}", - "shape": shape, "dtype": dtype, "transpose": transpose} - for shape in [(5, 8), (8, 5), (5, 5), (8, 8)] - for dtype in all_dtypes - for transpose in [True, False])) + @jtu.sample_product( + shape=[(5, 8), (8, 5), (5, 5), (8, 8)], + dtype=all_dtypes, + transpose=[True, False], + ) def test_csr_matmat(self, shape, dtype, transpose): op = lambda M: M.T if transpose else M @@ -341,11 +335,10 @@ class cuSparseTest(jtu.JaxTestCase): with self.gpu_matmul_warning_context(dtype): self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_{jtu.format_shape_dtype_string(shape, dtype)}", - "shape": shape, "dtype": dtype} - for shape in [(5, 8), (8, 5), (5, 5), (8, 8)] - for dtype in all_dtypes)) + @jtu.sample_product( + shape=[(5, 8), (8, 5), (5, 5), (8, 8)], + dtype=all_dtypes, + ) def test_coo_todense(self, shape, dtype): rng = rand_sparse(self.rng(), post=scipy.sparse.coo_matrix) M = rng(shape, dtype) @@ -357,11 +350,10 @@ class cuSparseTest(jtu.JaxTestCase): with self.gpu_dense_conversion_warning_context(dtype): self.assertArraysEqual(M.toarray(), jit(todense)(*args)) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_{jtu.format_shape_dtype_string(shape, dtype)}", - "shape": shape, "dtype": dtype} - for shape in [(5, 8), (8, 5), (5, 5), (8, 8)] - for dtype in all_dtypes)) + @jtu.sample_product( + shape=[(5, 8), (8, 5), (5, 5), (8, 8)], + dtype=all_dtypes, + ) def test_coo_fromdense(self, shape, dtype): rng = rand_sparse(self.rng()) M = rng(shape, dtype) @@ -382,12 +374,11 @@ class cuSparseTest(jtu.JaxTestCase): self.assertArraysEqual(row, M_coo.row.astype(index_dtype)) self.assertArraysEqual(col, M_coo.col.astype(index_dtype)) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_{jtu.format_shape_dtype_string(shape, dtype)}_T={transpose}", - "shape": shape, "dtype": dtype, "transpose": transpose} - for shape in [(5, 8), (8, 5), (5, 5), (8, 8)] - for dtype in all_dtypes - for transpose in [True, False])) + @jtu.sample_product( + shape=[(5, 8), (8, 5), (5, 5), (8, 8)], + dtype=all_dtypes, + transpose=[True, False], + ) def test_coo_matvec(self, shape, dtype, transpose): op = lambda M: M.T if transpose else M @@ -403,12 +394,11 @@ class cuSparseTest(jtu.JaxTestCase): with self.gpu_matmul_warning_context(dtype): self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_{jtu.format_shape_dtype_string(shape, dtype)}_T={transpose}", - "shape": shape, "dtype": dtype, "transpose": transpose} - for shape in [(5, 8), (8, 5), (5, 5), (8, 8)] - for dtype in all_dtypes - for transpose in [True, False])) + @jtu.sample_product( + shape=[(5, 8), (8, 5), (5, 5), (8, 8)], + dtype=all_dtypes, + transpose=[True, False], + ) @jtu.skip_on_devices("rocm") # will be fixed in rocm-5.1 def test_coo_matmat(self, shape, dtype, transpose): op = lambda M: M.T if transpose else M @@ -534,13 +524,11 @@ class cuSparseTest(jtu.JaxTestCase): self.assertIn(sparse.csr_todense_p, mlir._platform_specific_lowerings["rocm"]) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_{}".format( - jtu.format_shape_dtype_string(shape, dtype), mat_type), - "shape": shape, "dtype": dtype, "mat_type": mat_type} - for shape in [(5, 8), (8, 5), (5, 5), (8, 8)] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex - for mat_type in ['csr', 'coo'])) + @jtu.sample_product( + shape=[(5, 8), (8, 5), (5, 5), (8, 8)], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + mat_type=['csr', 'coo'], + ) def test_extra_nse(self, shape, dtype, mat_type): rng = rand_sparse(self.rng()) M = rng(shape, dtype) @@ -554,11 +542,10 @@ class cuSparseTest(jtu.JaxTestCase): M_out = todense(*args, shape=M.shape) self.assertArraysEqual(M, M_out) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_{jtu.format_shape_dtype_string(shape, dtype)}", - "shape": shape, "dtype": dtype} - for shape in [(5, 8), (8, 5), (5, 5), (8, 8)] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex)) + @jtu.sample_product( + shape=[(5, 8), (8, 5), (5, 5), (8, 8)], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) def test_coo_todense_ad(self, shape, dtype): rng = rand_sparse(self.rng(), post=jnp.array) M = rng(shape, dtype) @@ -576,11 +563,10 @@ class cuSparseTest(jtu.JaxTestCase): self.assertArraysEqual(primals, f(data)) self.assertArraysEqual(data_out, data) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_{jtu.format_shape_dtype_string(shape, dtype)}", - "shape": shape, "dtype": dtype} - for shape in [(5, 8), (8, 5), (5, 5), (8, 8)] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex)) + @jtu.sample_product( + shape=[(5, 8), (8, 5), (5, 5), (8, 8)], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) def test_coo_fromdense_ad(self, shape, dtype): rng = rand_sparse(self.rng(), post=jnp.array) M = rng(shape, dtype) @@ -605,14 +591,13 @@ class cuSparseTest(jtu.JaxTestCase): self.assertArraysEqual(M_out, M) @unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision") - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_{}".format( - jtu.format_shape_dtype_string(shape, dtype), - jtu.format_shape_dtype_string(bshape, dtype)), - "shape": shape, "dtype": dtype, "bshape": bshape} + @jtu.sample_product( + [dict(shape=shape, bshape=bshape) for shape in [(5, 8), (8, 5), (5, 5), (8, 8)] for bshape in [shape[-1:] + s for s in [(), (1,), (3,)]] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex)) + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) def test_coo_matmul_ad(self, shape, dtype, bshape): coo_matmul = sparse_coo._coo_matvec if len(bshape) == 1 else sparse_coo._coo_matmat tol = {np.float32: 1E-5, np.float64: 1E-12, np.complex64: 1E-5, np.complex128: 1E-12} @@ -681,14 +666,14 @@ class BCOOTest(jtu.JaxTestCase): self.assertEqual(repr(x), "DynamicJaxprTracer[BCOO(float32[5], nse=4)]") f(x) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_nbatch={}_ndense={}".format( - jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense), - "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense} + @jtu.sample_product( + [dict(shape=shape, n_batch=n_batch, n_dense=n_dense) for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for dtype in all_dtypes for n_batch in range(len(shape) + 1) - for n_dense in range(len(shape) + 1 - n_batch))) + for n_dense in range(len(shape) + 1 - n_batch) + ], + dtype=all_dtypes, + ) def test_empty(self, shape, dtype, n_batch, n_dense): M = sparse.empty(shape, dtype=dtype, n_batch=n_batch, n_dense=n_dense) self.assertIsInstance(M, sparse.BCOO) @@ -698,16 +683,16 @@ class BCOOTest(jtu.JaxTestCase): self.assertEqual(M.dtype, dtype) self.assertArraysEqual(M.todense(), jnp.empty(shape, dtype)) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_k={}_nbatch={}_ndense={}".format( - jtu.format_shape_dtype_string((N, M), dtype), k, n_batch, n_dense), - "N": N, "M": M, "k": k, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense} - for N in [3, 5] - for M in [None, 4] - for k in [-3, -1, 0, 2, 4] - for dtype in all_dtypes + @jtu.sample_product( + [dict(n_batch=n_batch, n_dense=n_dense) for n_batch in range(3) - for n_dense in range(3 - n_batch))) + for n_dense in range(3 - n_batch) + ], + N=[3, 5], + M=[None, 4], + k=[-3, -1, 0, 2, 4], + dtype=all_dtypes, + ) def test_eye(self, N, M, k, dtype, n_batch, n_dense): mat = sparse.eye(N, M, k, dtype=dtype, n_batch=n_batch, n_dense=n_dense) expected = jnp.eye(N, M, k, dtype=dtype) @@ -720,14 +705,14 @@ class BCOOTest(jtu.JaxTestCase): self.assertEqual(mat.nse, expected_nse) self.assertArraysEqual(mat.todense(), expected) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_nbatch={}_ndense={}".format( - jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense), - "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense} + @jtu.sample_product( + [dict(shape=shape, n_batch=n_batch, n_dense=n_dense) for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for dtype in all_dtypes for n_batch in range(len(shape) + 1) - for n_dense in range(len(shape) + 1 - n_batch))) + for n_dense in range(len(shape) + 1 - n_batch) + ], + dtype=all_dtypes, + ) def test_bcoo_dense_round_trip(self, shape, dtype, n_batch, n_dense): rng = rand_sparse(self.rng()) M = rng(shape, dtype) @@ -748,14 +733,14 @@ class BCOOTest(jtu.JaxTestCase): self.assertArraysEqual(M, todense(data, indices)) self.assertArraysEqual(M, jit(todense)(data, indices)) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_nbatch={}_ndense={}".format( - jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense), - "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense} + @jtu.sample_product( + [dict(shape=shape, n_batch=n_batch, n_dense=n_dense) for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for dtype in jtu.dtypes.floating for n_batch in range(len(shape) + 1) - for n_dense in range(len(shape) + 1 - n_batch))) + for n_dense in range(len(shape) + 1 - n_batch) + ], + dtype=jtu.dtypes.floating, + ) def test_bcoo_todense_ad(self, shape, dtype, n_batch, n_dense): rng = rand_sparse(self.rng()) M = rng(shape, dtype) @@ -771,14 +756,14 @@ class BCOOTest(jtu.JaxTestCase): self.assertEqual(j1.shape, M.shape + data.shape) self.assertEqual(hess.shape, M.shape + 2 * data.shape) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_nbatch={}_ndense={}".format( - jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense), - "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense} + @jtu.sample_product( + [dict(shape=shape, n_batch=n_batch, n_dense=n_dense) for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for dtype in jtu.dtypes.floating for n_batch in range(len(shape) + 1) - for n_dense in range(len(shape) + 1 - n_batch))) + for n_dense in range(len(shape) + 1 - n_batch) + ], + dtype=jtu.dtypes.floating, + ) def test_bcoo_fromdense_ad(self, shape, dtype, n_batch, n_dense): rng = rand_sparse(self.rng()) M = rng(shape, dtype) @@ -814,14 +799,14 @@ class BCOOTest(jtu.JaxTestCase): self.assertTrue(mat_unsorted.unique_indices) self.assertTrue(mat_resorted.unique_indices) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_nbatch={}_ndense={}".format( - jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense), - "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense} + @jtu.sample_product( + [dict(shape=shape, n_batch=n_batch, n_dense=n_dense) for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex for n_batch in range(len(shape) + 1) - for n_dense in range(len(shape) + 1 - n_batch))) + for n_dense in range(len(shape) + 1 - n_batch) + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) def test_bcoo_dense_round_trip_batched(self, shape, dtype, n_batch, n_dense): rng = rand_sparse(self.rng()) M = rng(shape, dtype) @@ -845,14 +830,14 @@ class BCOOTest(jtu.JaxTestCase): self.assertArraysEqual(M, todense(data, indices)) self.assertArraysEqual(M, jit(todense)(data, indices)) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_nbatch={}_ndense={}".format( - jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense), - "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense} + @jtu.sample_product( + [dict(shape=shape, n_batch=n_batch, n_dense=n_dense) for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex for n_batch in range(len(shape) + 1) - for n_dense in range(len(shape) + 1 - n_batch))) + for n_dense in range(len(shape) + 1 - n_batch) + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) def test_bcoo_extract(self, shape, dtype, n_batch, n_dense): rng = rand_sparse(self.rng()) M = rng(shape, dtype) @@ -884,14 +869,14 @@ class BCOOTest(jtu.JaxTestCase): actual = vmap(sparse.bcoo_extract, in_axes=0)(indices, mat) self.assertArraysEqual(expected, actual) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_nbatch={}_ndense={}".format( - jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense), - "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense} + @jtu.sample_product( + [dict(shape=shape, n_batch=n_batch, n_dense=n_dense) for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for dtype in jtu.dtypes.floating for n_batch in range(len(shape) + 1) - for n_dense in range(len(shape) + 1 - n_batch))) + for n_dense in range(len(shape) + 1 - n_batch) + ], + dtype=jtu.dtypes.floating, + ) def test_bcoo_extract_ad(self, shape, dtype, n_batch, n_dense): rng = rand_sparse(self.rng()) M = rng(shape, dtype) @@ -907,14 +892,14 @@ class BCOOTest(jtu.JaxTestCase): self.assertEqual(j1.shape, data.shape + M.shape) self.assertEqual(hess.shape, data.shape + 2 * M.shape) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_nbatch={}_ndense={}".format( - jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense), - "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense} + @jtu.sample_product( + [dict(shape=shape, n_batch=n_batch, n_dense=n_dense) for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for dtype in jtu.dtypes.floating for n_batch in range(len(shape) + 1) - for n_dense in range(len(shape) + 1 - n_batch))) + for n_dense in range(len(shape) + 1 - n_batch) + ], + dtype=jtu.dtypes.floating, + ) def test_bcoo_transpose(self, shape, dtype, n_batch, n_dense): n_sparse = len(shape) - n_batch - n_dense rng = self.rng() @@ -942,14 +927,14 @@ class BCOOTest(jtu.JaxTestCase): Msp = sparse.BCOO.fromdense(M, n_batch=n_batch, n_dense=n_dense) self.assertArraysEqual(trans(M), trans(Msp).todense()) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_nbatch={}_ndense={}".format( - jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense), - "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense} + @jtu.sample_product( + [dict(shape=shape, n_batch=n_batch, n_dense=n_dense) for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for dtype in jtu.dtypes.floating for n_batch in range(len(shape) + 1) - for n_dense in range(len(shape) + 1 - n_batch))) + for n_dense in range(len(shape) + 1 - n_batch) + ], + dtype=jtu.dtypes.floating, + ) def test_bcoo_slice(self, shape, dtype, n_batch, n_dense): rng = self.rng() sprng = rand_sparse(rng) @@ -970,14 +955,14 @@ class BCOOTest(jtu.JaxTestCase): self.assertArraysEqual(dense_result, sparse_result.todense()) self.assertArraysEqual(dense_result, sparse_result_jit.todense()) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_nbatch={}_ndense={}".format( - jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense), - "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense} + @jtu.sample_product( + [dict(shape=shape, n_batch=n_batch, n_dense=n_dense) for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for dtype in jtu.dtypes.floating for n_batch in range(len(shape) + 1) - for n_dense in range(len(shape) + 1 - n_batch))) + for n_dense in range(len(shape) + 1 - n_batch) + ], + dtype=jtu.dtypes.floating, + ) def test_bcoo_dynamic_slice(self, shape, dtype, n_batch, n_dense): rng = self.rng() sprng = rand_sparse(rng) @@ -997,16 +982,15 @@ class BCOOTest(jtu.JaxTestCase): self.assertArraysEqual(dense_result, sparse_result.todense()) self.assertArraysEqual(dense_result, sparse_result_jit.todense()) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_nbatch={}_ndense={}_idx={}".format( - jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense, idx), - "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense, - "idx": idx} + @jtu.sample_product( + [dict(shape=shape, n_batch=n_batch, n_dense=n_dense) for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for dtype in jtu.dtypes.floating for n_batch in range(len(shape) + 1) for n_dense in range(len(shape) + 1 - n_batch) - for idx in [1, slice(1, 3)])) + ], + dtype=jtu.dtypes.floating, + idx=[1, slice(1, 3)], + ) def test_bcoo_getitem(self, shape, dtype, n_batch, n_dense, idx): # Note: __getitem__ is currently only supported for simple slices and indexing rng = self.rng() @@ -1015,14 +999,14 @@ class BCOOTest(jtu.JaxTestCase): Msp = sparse.BCOO.fromdense(M, n_batch=n_batch, n_dense=n_dense) self.assertArraysEqual(M[idx], Msp[idx].todense()) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_nbatch={}_ndense={}".format( - jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense), - "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense} + @jtu.sample_product( + [dict(shape=shape, n_batch=n_batch, n_dense=n_dense) for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for dtype in jtu.dtypes.floating for n_batch in range(len(shape) + 1) - for n_dense in range(len(shape) + 1 - n_batch))) + for n_dense in range(len(shape) + 1 - n_batch) + ], + dtype=jtu.dtypes.floating, + ) def test_bcoo_transpose_ad(self, shape, dtype, n_batch, n_dense): n_sparse = len(shape) - n_batch - n_dense rng = self.rng() @@ -1067,14 +1051,14 @@ class BCOOTest(jtu.JaxTestCase): mat_T_indices_unsorted = mat.transpose(axes=permutations) self.assertFalse(mat_T_indices_unsorted.indices_sorted) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_nbatch={}_ndense={}".format( - jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense), - "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense} + @jtu.sample_product( + [dict(shape=shape, n_batch=n_batch, n_dense=n_dense) for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex for n_batch in range(1, len(shape) + 1) - for n_dense in range(len(shape) + 1 - n_batch))) + for n_dense in range(len(shape) + 1 - n_batch) + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) def test_bcoo_todense_partial_batch(self, shape, dtype, n_batch, n_dense): rng = rand_sparse(self.rng()) M = rng(shape, dtype) @@ -1090,12 +1074,12 @@ class BCOOTest(jtu.JaxTestCase): M4 = sparse_bcoo._bcoo_todense(jnp.stack(shape[0] * [data[0]]), indices, spinfo=BCOOInfo(M.shape)) self.assertAllClose(M3, M4) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": props.testcase_name(), "props": props} - for props in _generate_bcoo_dot_general_properties( - shapes=[(5,), (2, 3), (2, 3, 4), (2, 3, 4, 4)], - dtypes=jtu.dtypes.floating + jtu.dtypes.complex, - ))) + @jtu.sample_product( + props=_generate_bcoo_dot_general_properties( + shapes=[(5,), (2, 3), (2, 3, 4), (2, 3, 4, 4)], + dtypes=jtu.dtypes.floating + jtu.dtypes.complex, + ) + ) def test_bcoo_dot_general(self, props: BcooDotGeneralProperties): rng = jtu.rand_small(self.rng()) rng_sparse = rand_sparse(self.rng()) @@ -1123,14 +1107,9 @@ class BCOOTest(jtu.JaxTestCase): @unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse") @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU") - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_lhs_shape={}_rhs_shape={}_lhs_contracting={}_rhs_contracting={}" - .format(jtu.format_shape_dtype_string(lhs_shape, dtype), - jtu.format_shape_dtype_string(rhs_shape, dtype), - lhs_contracting, rhs_contracting), - "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, - "lhs_contracting": lhs_contracting, "rhs_contracting": rhs_contracting} + @jtu.sample_product( + [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, + lhs_contracting=lhs_contracting, rhs_contracting=rhs_contracting) for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [ [(5,), (5,), [0], [0]], [(5,), (5, 7), [0], [0]], @@ -1142,7 +1121,9 @@ class BCOOTest(jtu.JaxTestCase): [(5, 3), (2, 5), [0], [1]], [(5, 3), (5, 2), [0], [0]], ] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex)) + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) def test_bcoo_dot_general_cusparse( self, lhs_shape, rhs_shape, dtype, lhs_contracting, rhs_contracting): rng = jtu.rand_small(self.rng()) @@ -1168,15 +1149,9 @@ class BCOOTest(jtu.JaxTestCase): @unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse") @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU") - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_n_batch={}_lhs_shape={}_rhs_shape={}_lhs_contracting={}_rhs_contracting={}" - .format(n_batch, jtu.format_shape_dtype_string(lhs_shape, dtype), - jtu.format_shape_dtype_string(rhs_shape, dtype), - lhs_contracting, rhs_contracting), - "n_batch": n_batch, "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, - "dtype": dtype, "lhs_contracting": lhs_contracting, - "rhs_contracting": rhs_contracting} + @jtu.sample_product( + [dict(n_batch=n_batch, lhs_shape=lhs_shape, rhs_shape=rhs_shape, + lhs_contracting=lhs_contracting, rhs_contracting=rhs_contracting) for n_batch, lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [ [1, (1, 2, 3), (3, 2), [2], [0]], [1, (1, 3, 2), (3, 2), [1], [0]], @@ -1185,7 +1160,9 @@ class BCOOTest(jtu.JaxTestCase): [1, (4, 2, 3), (2, 5), [1], [0]], [1, (4, 2, 3), (5, 3), [2], [1]], ] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex)) + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) @jtu.skip_on_devices("rocm") def test_bcoo_batched_matmat_cusparse( self, n_batch, lhs_shape, rhs_shape, dtype, lhs_contracting, @@ -1231,20 +1208,16 @@ class BCOOTest(jtu.JaxTestCase): @unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse") @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU") - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_n_batch={}_lhs_shape={}_rhs_shape={}_lhs_contracting={}_rhs_contracting={}" - .format(n_batch, jtu.format_shape_dtype_string(lhs_shape, dtype), - jtu.format_shape_dtype_string(rhs_shape, dtype), - lhs_contracting, rhs_contracting), - "n_batch": n_batch, "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, - "dtype": dtype, "lhs_contracting": lhs_contracting, - "rhs_contracting": rhs_contracting} + @jtu.sample_product( + [dict(n_batch=n_batch, lhs_shape=lhs_shape, rhs_shape=rhs_shape, + lhs_contracting=lhs_contracting, rhs_contracting=rhs_contracting) for n_batch, lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [ [1, (1, 2, 3), (3), [2], [0]], [1, (1, 2), (3, 2), [1], [1]], ] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex)) + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) @jtu.skip_on_devices("rocm") def test_bcoo_batched_matmat_default_lowering( self, n_batch, lhs_shape, rhs_shape, dtype, lhs_contracting, @@ -1328,12 +1301,12 @@ class BCOOTest(jtu.JaxTestCase): with self.subTest(msg="1D"): self.assertArraysEqual(vecmat_expected, vecmat_unsorted_fallback) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": props.testcase_name(), "props": props} - for props in _generate_bcoo_dot_general_properties( - shapes=[(5,), (2, 3), (2, 3, 4), (2, 3, 4, 4)], - dtypes=jtu.dtypes.floating + jtu.dtypes.complex, - ))) + @jtu.sample_product( + props=_generate_bcoo_dot_general_properties( + shapes=[(5,), (2, 3), (2, 3, 4), (2, 3, 4, 4)], + dtypes=jtu.dtypes.floating + jtu.dtypes.complex, + ) + ) def test_bcoo_rdot_general(self, props: BcooDotGeneralProperties): rng = jtu.rand_small(self.rng()) rng_sparse = rand_sparse(self.rng()) @@ -1364,15 +1337,9 @@ class BCOOTest(jtu.JaxTestCase): # TODO(jakevdp): In rare cases, this fails python_should_be_executing check. Why? # self._CompileAndCheck(f_sparse, args_maker) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_n_batch={}_n_dense={}" - .format(jtu.format_shape_dtype_string(lhs_shape, dtype), - jtu.format_shape_dtype_string(rhs_shape, dtype), - dimension_numbers, n_batch, n_dense), - "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, - "dimension_numbers": dimension_numbers, - "n_batch": n_batch, "n_dense": n_dense} + @jtu.sample_product( + [dict(n_batch=n_batch, n_dense=n_dense, lhs_shape=lhs_shape, + rhs_shape=rhs_shape, dimension_numbers=dimension_numbers) for lhs_shape, rhs_shape, dimension_numbers, n_batch, n_dense in [ ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 1, 0), ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 2, 0), @@ -1381,7 +1348,9 @@ class BCOOTest(jtu.JaxTestCase): ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 0), ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 1), ] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex)) + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) def test_bcoo_dot_general_partial_batch(self, lhs_shape, rhs_shape, dtype, dimension_numbers, n_batch, n_dense): rng = jtu.rand_small(self.rng()) @@ -1404,15 +1373,9 @@ class BCOOTest(jtu.JaxTestCase): X = sparse_bcoo._bcoo_todense(data, indices, spinfo=BCOOInfo(X.shape)) self.assertAllClose(f_dense(X, Y), f_sparse(data, indices, Y)) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_n_batch={}_n_dense={}" - .format(jtu.format_shape_dtype_string(lhs_shape, dtype), - jtu.format_shape_dtype_string(rhs_shape, dtype), - dimension_numbers, n_batch, n_dense), - "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, - "dimension_numbers": dimension_numbers, - "n_batch": n_batch, "n_dense": n_dense} + @jtu.sample_product( + [dict(n_batch=n_batch, n_dense=n_dense, lhs_shape=lhs_shape, + rhs_shape=rhs_shape, dimension_numbers=dimension_numbers) for lhs_shape, rhs_shape, dimension_numbers, n_batch, n_dense in [ ((4, 5), (5, 3), (([1], [0]), ([], [])), 0, 0), ((2, 4, 5), (2, 5, 3), (([2], [1]), ([0], [0])), 1, 0), @@ -1424,7 +1387,9 @@ class BCOOTest(jtu.JaxTestCase): # This requires contraction over dense dimensions, which is not yet implemented: # ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 1), ] - for dtype in jtu.dtypes.floating)) + ], + dtype=jtu.dtypes.floating, + ) def test_bcoo_dot_general_ad(self, lhs_shape, rhs_shape, dtype, dimension_numbers, n_batch, n_dense): rng = jtu.rand_small(self.rng()) @@ -1484,15 +1449,9 @@ class BCOOTest(jtu.JaxTestCase): extract = jax.vmap(extract) self.assertAllClose(extract(jf_dense), jf_sparse, rtol=tol) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_n_batch={}_n_dense={}" - .format(jtu.format_shape_dtype_string(lhs_shape, dtype), - jtu.format_shape_dtype_string(rhs_shape, dtype), - dimension_numbers, n_batch, n_dense), - "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, - "dimension_numbers": dimension_numbers, - "n_batch": n_batch, "n_dense": n_dense} + @jtu.sample_product( + [dict(n_batch=n_batch, n_dense=n_dense, lhs_shape=lhs_shape, + rhs_shape=rhs_shape, dimension_numbers=dimension_numbers) for lhs_shape, rhs_shape, dimension_numbers, n_batch, n_dense in [ ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 0, 0), ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 1, 0), @@ -1503,7 +1462,9 @@ class BCOOTest(jtu.JaxTestCase): ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 1, 2), ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 1), ] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex)) + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) def test_bcoo_dot_general_sampled(self, lhs_shape, rhs_shape, dtype, dimension_numbers, n_batch, n_dense): rng = jtu.rand_default(self.rng()) sprng = rand_sparse(self.rng()) @@ -1531,15 +1492,9 @@ class BCOOTest(jtu.JaxTestCase): # TODO: python_should_be_executing check occasionally fails... why? # self._CompileAndCheck(sparse_fun, args_maker) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_n_batch={}_n_dense={}" - .format(jtu.format_shape_dtype_string(lhs_shape, dtype), - jtu.format_shape_dtype_string(rhs_shape, dtype), - dimension_numbers, n_batch, n_dense), - "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, - "dimension_numbers": dimension_numbers, - "n_batch": n_batch, "n_dense": n_dense} + @jtu.sample_product( + [dict(n_batch=n_batch, n_dense=n_dense, lhs_shape=lhs_shape, + rhs_shape=rhs_shape, dimension_numbers=dimension_numbers) for lhs_shape, rhs_shape, dimension_numbers, n_batch, n_dense in [ ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 1, 0), ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 1, 1), @@ -1550,7 +1505,9 @@ class BCOOTest(jtu.JaxTestCase): ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 0), ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 1), ] - for dtype in jtu.dtypes.floating)) + ], + dtype=jtu.dtypes.floating, + ) def test_bcoo_dot_general_sampled_ad(self, lhs_shape, rhs_shape, dtype, dimension_numbers, n_batch, n_dense): rng = jtu.rand_default(self.rng()) sprng = rand_sparse(self.rng()) @@ -1584,14 +1541,9 @@ class BCOOTest(jtu.JaxTestCase): self.assertAllClose(jf_sparse, jr_sparse, atol=tol) @unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision") - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}[n_batch={}]_{}[n_batch={}]_swap={}_dims={}".format( - jtu.format_shape_dtype_string(lhs_shape, dtype), lhs_n_batch, - jtu.format_shape_dtype_string(rhs_shape, dtype), rhs_n_batch, - swap, dimension_numbers), - "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, - "lhs_n_batch": lhs_n_batch, "rhs_n_batch": rhs_n_batch, - "dimension_numbers": dimension_numbers, "swap": swap, "dtype": dtype} + @jtu.sample_product( + [dict(lhs_n_batch=lhs_n_batch, rhs_n_batch=rhs_n_batch, lhs_shape=lhs_shape, + rhs_shape=rhs_shape, dimension_numbers=dimension_numbers) for lhs_shape, lhs_n_batch, rhs_shape, rhs_n_batch, dimension_numbers in [ # (batched) outer products (no contraction) ((5,), 0, (6,), 0, (([], []), ([], []))), @@ -1619,8 +1571,10 @@ class BCOOTest(jtu.JaxTestCase): ((2, 3, 4, 3), 1, (2, 4, 3, 4), 1, (([2, 3], [1, 2]), ([0], [0]))), ((2, 3, 4, 3, 1), 2, (3, 2, 3, 4), 2, (([2, 3], [3, 2]), ([0, 1], [1, 0]))), ] - for swap in [True, False] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex)) + ], + swap=[True, False], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) def test_bcoo_spdot_general(self, lhs_shape, lhs_n_batch, rhs_shape, rhs_n_batch, dtype, swap, dimension_numbers): if swap: dimension_numbers = tuple(d[::-1] for d in dimension_numbers) @@ -1672,22 +1626,20 @@ class BCOOTest(jtu.JaxTestCase): # matrix-matrix product -> product of nse N = sparse.BCOO.fromdense(jnp.arange(12).reshape(3, 4)) self.assertEqual((M @ N).nse, M.nse * N.nse) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_lhs_shape={}[n_batch={}]_rhs_shape={}[n_batch={}]_dimension_numbers={}" - .format(jtu.format_shape_dtype_string(lhs_shape, dtype), lhs_n_batch, - jtu.format_shape_dtype_string(rhs_shape, dtype), rhs_n_batch, - dimension_numbers), - "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, - "dimension_numbers": dimension_numbers, - "lhs_n_batch": lhs_n_batch, "rhs_n_batch": rhs_n_batch} + + @jtu.sample_product( + [dict(lhs_n_batch=lhs_n_batch, rhs_n_batch=rhs_n_batch, lhs_shape=lhs_shape, + rhs_shape=rhs_shape, dimension_numbers=dimension_numbers) for lhs_shape, lhs_n_batch, rhs_shape, rhs_n_batch, dimension_numbers in [ ((4, 5), 0, (5,), 0, (([1], [0]), ([], []))), ((2, 4, 5), 1, (5,), 0, (([2], [0]), ([], []))), ((4, 5), 0, (5, 3), 0, (([1], [0]), ([], []))), ((2, 4, 5), 1, (2, 5, 3), 1, (([2], [1]), ([0], [0]))), ] - for dtype in jtu.dtypes.floating)) + ], + dtype=jtu.dtypes.floating, + ) + @jax.default_matmul_precision("float32") def test_bcoo_spdot_general_ad(self, lhs_shape, rhs_shape, dtype, dimension_numbers, lhs_n_batch, rhs_n_batch): rng = rand_sparse(self.rng()) @@ -1758,14 +1710,9 @@ class BCOOTest(jtu.JaxTestCase): self.assertAllClose(sp_de_jac, de_de_jac) @unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision") - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}[n_batch={}]_{}[n_batch={}]_in_axes={}".format( - jtu.format_shape_dtype_string(lhs_shape, dtype), lhs_n_batch, - jtu.format_shape_dtype_string(rhs_shape, dtype), rhs_n_batch, - in_axes), - "lhs_shape": lhs_shape, "lhs_n_batch": lhs_n_batch, - "rhs_shape": rhs_shape, "rhs_n_batch": rhs_n_batch, - "dtype": dtype, "in_axes": in_axes} + @jtu.sample_product( + [dict(lhs_n_batch=lhs_n_batch, rhs_n_batch=rhs_n_batch, lhs_shape=lhs_shape, + rhs_shape=rhs_shape, in_axes=in_axes) for lhs_shape, lhs_n_batch, rhs_shape, rhs_n_batch, in_axes in [ ((3, 5), 1, (3, 5), 1, 0), ((3, 4, 5), 1, (3, 5), 1, 0), @@ -1775,7 +1722,9 @@ class BCOOTest(jtu.JaxTestCase): # ((3, 4, 5), 1, (5,), 0, (0, None)), # ((4, 5), 0, (3, 5), 1, (None, 0)), ] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex)) + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) def test_bcoo_spmm_batched(self, lhs_shape, lhs_n_batch, rhs_shape, rhs_n_batch, dtype, in_axes): sprng = rand_sparse(self.rng()) def args_maker(): @@ -1797,17 +1746,16 @@ class BCOOTest(jtu.JaxTestCase): result_sparse_jit = jax.jit(f_sparse)(*args) self.assertAllClose(result_dense, result_sparse_jit.todense()) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_nbatch={}_ndense={}_nse={}_remove_zeros={}".format( - jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense, nse, remove_zeros), - "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense, - "nse": nse, "remove_zeros": remove_zeros} + @jtu.sample_product( + [dict(shape=shape, n_batch=n_batch, n_dense=n_dense, nse=nse) for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex for n_batch in range(len(shape) + 1) for n_dense in range(len(shape) + 1 - n_batch) for nse in [None, np.prod(shape) - 1] - for remove_zeros in [True, False])) + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + remove_zeros=[True, False], + ) def test_bcoo_sum_duplicates(self, shape, dtype, n_batch, n_dense, nse, remove_zeros): # Create a matrix with duplicate indices rng_sparse = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero) @@ -1834,16 +1782,15 @@ class BCOOTest(jtu.JaxTestCase): self.assertTrue(M_dedup.unique_indices) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_nbatch={}_ndense={}_nse={}".format( - jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense, nse), - "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense, "nse": nse} + @jtu.sample_product( + [dict(shape=shape, n_batch=n_batch, n_dense=n_dense, nse=nse) for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for dtype in jtu.dtypes.floating for n_batch in range(len(shape) + 1) for n_dense in range(len(shape) + 1 - n_batch) for nse in [None, 5, np.prod(shape) - 1] - )) + ], + dtype=jtu.dtypes.floating, + ) def test_bcoo_sum_duplicates_ad(self, shape, dtype, n_batch, n_dense, nse): # Create a matrix with duplicate indices rng_sparse = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero) @@ -1869,14 +1816,14 @@ class BCOOTest(jtu.JaxTestCase): self.assertAllClose(data_dot_fwd, data_dot_rev) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_nbatch={}_ndense={}".format( - jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense), - "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense} + @jtu.sample_product( + [dict(shape=shape, n_batch=n_batch, n_dense=n_dense) for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex for n_batch in range(len(shape) + 1) - for n_dense in range(len(shape) + 1 - n_batch))) + for n_dense in range(len(shape) + 1 - n_batch) + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) def test_bcoo_sort_indices(self, shape, dtype, n_batch, n_dense): rng_sparse = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero) M = sparse.BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch, n_dense=n_dense) @@ -1892,14 +1839,14 @@ class BCOOTest(jtu.JaxTestCase): sorted = jax.vmap(jnp.lexsort)(flatind[:, ::-1]) self.assertArraysEqual(sorted, lax.broadcasted_iota(sorted.dtype, sorted.shape, sorted.ndim - 1)) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_nbatch={}_ndense={}".format( - jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense), - "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense} + @jtu.sample_product( + [dict(shape=shape, n_batch=n_batch, n_dense=n_dense) for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for dtype in jtu.dtypes.floating for n_batch in range(len(shape) + 1) - for n_dense in range(len(shape) + 1 - n_batch))) + for n_dense in range(len(shape) + 1 - n_batch) + ], + dtype=jtu.dtypes.floating, + ) def test_bcoo_sort_indices_ad(self, shape, dtype, n_batch, n_dense): rng_sparse = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero) M = sparse.BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch, n_dense=n_dense) @@ -1971,16 +1918,16 @@ class BCOOTest(jtu.JaxTestCase): self.assertArraysEqual(x.indices, y.indices) self.assertArraysEqual(x.data, y.data) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_nbatch={}_ndense={}_axes={}".format( - jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense, axes), - "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense, "axes": axes} + @jtu.sample_product( + [dict(shape=shape, n_batch=n_batch, n_dense=n_dense, axes=axes) for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex for n_batch in range(len(shape) + 1) for n_dense in range(len(shape) + 1 - n_batch) for naxes in range(len(shape)) - for axes in itertools.combinations(range(len(shape)), naxes))) + for axes in itertools.combinations(range(len(shape)), naxes) + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) def test_bcoo_reduce_sum(self, shape, dtype, n_batch, n_dense, axes): rng = rand_sparse(self.rng()) M = rng(shape, dtype) @@ -2003,13 +1950,8 @@ class BCOOTest(jtu.JaxTestCase): y.reshape(2, 3, 2) @unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision") - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_{}".format( - jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), - jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)), - "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype, - "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype, - } + @jtu.sample_product( + [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape) for lhs_shape, rhs_shape in [[(3,), (3,)], [(3, 4), (4,)], [(4,), (4, 5)], @@ -2017,8 +1959,10 @@ class BCOOTest(jtu.JaxTestCase): [(3, 4), (2, 4, 5)], [(2, 3, 4), (4, 5)], [(2, 3, 4), (2, 4, 5)]] - for lhs_dtype in all_dtypes - for rhs_dtype in all_dtypes)) + ], + lhs_dtype=all_dtypes, + rhs_dtype=all_dtypes, + ) def test_bcoo_matmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): rng = jtu.rand_default(self.rng()) lhs = jnp.array(rng(lhs_shape, lhs_dtype)) @@ -2039,22 +1983,18 @@ class BCOOTest(jtu.JaxTestCase): self.assertAllClose(out1, out2, rtol=tol) self.assertAllClose(out1, out3, rtol=tol) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_{}_n_batch={}_n_dense={}".format( - jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), - jtu.format_shape_dtype_string(rhs_shape, rhs_dtype), - n_batch, n_dense), - "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype, - "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype, - "n_batch": n_batch, "n_dense": n_dense, - } + @jtu.sample_product( + [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, n_batch=n_batch, + n_dense=n_dense) for lhs_shape, rhs_shape in [[(3,), ()], [(3,), (1,)], [(3,), (3,)], [(3, 4), ()], [(3, 4), (4,)], [(3, 4), (3, 1)], [(3, 4), (3, 4)], [(3, 4, 5), (4, 5)], [(3, 4, 5), (3, 1, 1)], [(3, 4, 5), (1, 4, 1)]] for n_batch in range(len(lhs_shape) + 1) for n_dense in range(len(lhs_shape) + 1 - n_batch) - for lhs_dtype in all_dtypes - for rhs_dtype in all_dtypes)) + ], + lhs_dtype=all_dtypes, + rhs_dtype=all_dtypes, + ) @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. def test_bcoo_mul_dense(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, n_batch, n_dense): rng_lhs = rand_sparse(self.rng()) @@ -2073,14 +2013,10 @@ class BCOOTest(jtu.JaxTestCase): np.float32: 1E-6, np.complex64: 1E-6} self.assertAllClose(out1, out2, rtol=tol) self.assertAllClose(out1, out3, rtol=tol) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_n_batch={}_{}_n_batch={}_n_dense={}".format( - jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), lhs_n_batch, - jtu.format_shape_dtype_string(rhs_shape, rhs_dtype), rhs_n_batch, n_dense), - "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype, - "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype, - "lhs_n_batch": lhs_n_batch, "rhs_n_batch": rhs_n_batch, "n_dense": n_dense, - } + + @jtu.sample_product( + [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, lhs_n_batch=lhs_n_batch, + rhs_n_batch=rhs_n_batch, n_dense=n_dense) # TODO(jakevdp): add broadcasted shapes (from bcoo_mul_dense) once sparse-sparse mul # supports inputs of differing rank. for lhs_shape, rhs_shape in [[(3,), (1,)], [(3,), (3,)], @@ -2090,8 +2026,10 @@ class BCOOTest(jtu.JaxTestCase): for lhs_n_batch in range(len(lhs_shape) + 1) for rhs_n_batch in range(len(lhs_shape) + 1) for n_dense in range(len(lhs_shape) + 1 - max(lhs_n_batch, rhs_n_batch)) - for lhs_dtype in all_dtypes - for rhs_dtype in all_dtypes)) + ], + lhs_dtype=all_dtypes, + rhs_dtype=all_dtypes, + ) def test_bcoo_mul_sparse(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, lhs_n_batch, rhs_n_batch, n_dense): rng = rand_sparse(self.rng()) lhs = jnp.array(rng(lhs_shape, lhs_dtype)) @@ -2116,14 +2054,14 @@ class BCOOTest(jtu.JaxTestCase): mat = sparse.BCOO((data, indices), shape=(3, 3)) self.assertArraysEqual((mat * mat).todense(), mat.todense() * mat.todense()) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_n_batch={}_n_dense={}".format( - jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense), - "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense} + @jtu.sample_product( + [dict(shape=shape, n_batch=n_batch, n_dense=n_dense) for shape in [(), (3,), (3, 5), (3, 5, 4)] - for dtype in all_dtypes for n_batch in range(len(shape) + 1) - for n_dense in range(len(shape) + 1 - n_batch))) + for n_dense in range(len(shape) + 1 - n_batch) + ], + dtype=all_dtypes, + ) def test_bcoo_broadcast_in_dim(self, shape, dtype, n_batch, n_dense): rng = rand_sparse(self.rng()) x = jnp.array(rng(shape, dtype)) @@ -2141,15 +2079,15 @@ class BCOOTest(jtu.JaxTestCase): self.assertArraysEqual(xsp[:, :, None].todense(), x[:, :, None]) self.assertArraysEqual(xsp[:, None, :, None].todense(), x[:, None, :, None]) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_n_batch={}_n_dense={}_dimension={}".format( - jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense, dimension), - "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense, "dimension": dimension} + @jtu.sample_product( + [dict(shape=shape, n_batch=n_batch, n_dense=n_dense, dimension=dimension) for shape in [ (3,), (3, 5), (3, 5, 4)] - for dtype in all_dtypes for n_batch in range(len(shape) + 1) for n_dense in range(len(shape) + 1 - n_batch) - for dimension in range(len(shape) - n_dense))) # Concatenation of dense dimensions not implemented. + for dimension in range(len(shape) - n_dense) # Concatenation of dense dimensions not implemented. + ], + dtype=all_dtypes, + ) def test_bcoo_concatenate(self, shape, dtype, n_batch, n_dense, dimension): rng = rand_sparse(self.rng()) operands_dense = [rng(shape, dtype) for i in range(3)] @@ -2178,18 +2116,17 @@ class BCOOTest(jtu.JaxTestCase): self.assertEqual(Msp_dense.shape, M.shape) self.assertArraysEqual(Msp_dense, M) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_nbatch={}->{}_ndense={}->{}".format( - jtu.format_shape_dtype_string(shape, dtype), - n_batch, n_batch_out, n_dense, n_dense_out), - "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense, - "n_batch_out": n_batch_out, "n_dense_out": n_dense_out} + @jtu.sample_product( + [dict(shape=shape, n_batch=n_batch, n_dense=n_dense, n_batch_out=n_batch_out, + n_dense_out=n_dense_out) for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for dtype in jtu.dtypes.integer for n_batch in range(len(shape) + 1) for n_dense in range(len(shape) + 1 - n_batch) for n_batch_out in range(len(shape) + 1) - for n_dense_out in range(len(shape) + 1 - n_batch_out))) + for n_dense_out in range(len(shape) + 1 - n_batch_out) + ], + dtype=jtu.dtypes.integer, + ) def test_bcoo_update_layout(self, shape, dtype, n_batch, n_batch_out, n_dense, n_dense_out): rng = rand_sparse(self.rng()) mat = sparse.BCOO.fromdense(rng(shape, dtype), n_batch=n_batch, n_dense=n_dense) @@ -2252,13 +2189,13 @@ class BCOOTest(jtu.JaxTestCase): # TODO(tianjianlu): Unify the testing for BCOOTest and BCSRTest. class BCSRTest(jtu.JaxTestCase): - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_nbatch={}".format( - jtu.format_shape_dtype_string(shape, dtype), n_batch), - "shape": shape, "dtype": dtype, "n_batch": n_batch} + @jtu.sample_product( + [dict(shape=shape, n_batch=n_batch) for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex - for n_batch in range(len(shape) - 1))) + for n_batch in range(len(shape) - 1) + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) def test_bcsr_dense_round_trip(self, shape, dtype, n_batch): n_sparse = 2 n_dense = len(shape) - n_sparse - n_batch @@ -2452,13 +2389,12 @@ class SparseObjectTest(jtu.JaxTestCase): else: raise ValueError("Obj={Obj} not expected.") - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": "_{}_Obj={}".format( - jtu.format_shape_dtype_string(shape, dtype), Obj.__name__), - "shape": shape, "dtype": dtype, "Obj": Obj} - for shape in [(5, 8), (8, 5), (5, 5), (8, 8)] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex) + @parameterized.parameters(itertools.chain.from_iterable( + jtu.sample_product_testcases( + Obj=[Obj], + shape=[(5, 8), (8, 5), (5, 5), (8, 8)], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) def test_dense_round_trip(self, shape, dtype, Obj): rng = rand_sparse(self.rng()) @@ -2466,13 +2402,12 @@ class SparseObjectTest(jtu.JaxTestCase): Msparse = Obj.fromdense(M) self.assertArraysEqual(M, Msparse.todense()) - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": "_{}_Obj={}".format( - jtu.format_shape_dtype_string(shape, dtype), Obj.__name__), - "shape": shape, "dtype": dtype, "Obj": Obj} - for shape in [(5, 8), (8, 5), (5, 5), (8, 8)] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex) + @parameterized.parameters(itertools.chain.from_iterable( + jtu.sample_product_testcases( + Obj=[Obj], + shape=[(5, 8), (8, 5), (5, 5), (8, 8)], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) def test_transpose(self, shape, dtype, Obj): rng = rand_sparse(self.rng()) @@ -2481,14 +2416,15 @@ class SparseObjectTest(jtu.JaxTestCase): self.assertArraysEqual(M.T, Msparse.T.todense()) @unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision") - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": "_{}_Obj={}_bshape={}".format( - jtu.format_shape_dtype_string(shape, dtype), Obj.__name__, bshape), - "shape": shape, "dtype": dtype, "Obj": Obj, "bshape": bshape} - for shape in [(5, 8), (8, 5), (5, 5), (8, 8)] - for bshape in [shape[-1:] + s for s in [(), (3,), (4,)]] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex) + @parameterized.parameters(itertools.chain.from_iterable( + jtu.sample_product_testcases( + [dict(shape=shape, bshape=bshape) + for shape in [(5, 8), (8, 5), (5, 5), (8, 8)] + for bshape in [shape[-1:] + s for s in [(), (3,), (4,)]] + ], + Obj=[Obj], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) def test_matmul(self, shape, dtype, Obj, bshape): rng = rand_sparse(self.rng(), post=jnp.array) @@ -2507,14 +2443,12 @@ class SparseObjectTest(jtu.JaxTestCase): with jax.numpy_dtype_promotion('standard'): self.assertAllClose(M @ x, Msp @ x, rtol=MATMUL_TOL) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}({})".format( - input_type.__name__, - jtu.format_shape_dtype_string(shape, dtype)), - "input_type": input_type, "shape": shape, "dtype": dtype} - for input_type in [scipy.sparse.coo_matrix, scipy.sparse.csr_matrix, scipy.sparse.csc_matrix] - for shape in [(5, 8), (8, 5), (5, 5), (8, 8)] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex)) + @jtu.sample_product( + input_type=[scipy.sparse.coo_matrix, scipy.sparse.csr_matrix, + scipy.sparse.csc_matrix], + shape=[(5, 8), (8, 5), (5, 5), (8, 8)], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) def test_bcoo_from_scipy_sparse(self, input_type, shape, dtype): rng = rand_sparse(self.rng()) M = rng(shape, dtype) @@ -2537,13 +2471,13 @@ class SparseObjectTest(jtu.JaxTestCase): self.assertArraysEqual(M.sum(1), Msp.sum(1).todense()) self.assertArraysEqual(M.sum(), Msp.sum()) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_nbatch={}".format( - jtu.format_shape_dtype_string(shape, dtype), n_batch), - "shape": shape, "dtype": dtype, "n_batch": n_batch} + @jtu.sample_product( + [dict(shape=shape, n_batch=n_batch) for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex - for n_batch in range(len(shape) - 1))) + for n_batch in range(len(shape) - 1) + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) def test_bcoo_to_bcsr_round_trip(self, shape, dtype, n_batch): rng = rand_sparse(self.rng()) M = rng(shape, dtype) @@ -2572,16 +2506,15 @@ class SparseObjectTest(jtu.JaxTestCase): class SparseRandomTest(jtu.JaxTestCase): - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_indices_dtype={}_nbatch={}_ndense={}".format( - jtu.format_shape_dtype_string(shape, dtype), indices_dtype, n_batch, n_dense), - "shape": shape, "dtype": dtype, "indices_dtype": indices_dtype, - "n_batch": n_batch, "n_dense": n_dense} + @jtu.sample_product( + [dict(shape=shape, n_batch=n_batch, n_dense=n_dense) for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for dtype in jtu.dtypes.floating - for indices_dtype in jtu.dtypes.integer for n_batch in range(len(shape) + 1) - for n_dense in range(len(shape) + 1 - n_batch))) + for n_dense in range(len(shape) + 1 - n_batch) + ], + dtype=jtu.dtypes.floating, + indices_dtype=jtu.dtypes.integer, + ) def test_random_bcoo(self, shape, dtype, indices_dtype, n_batch, n_dense): key = jax.random.PRNGKey(1701) mat = sparse.random_bcoo( @@ -2604,13 +2537,11 @@ class SparseRandomTest(jtu.JaxTestCase): class SparseSolverTest(jtu.JaxTestCase): - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_re{}_({})".format(reorder, - jtu.format_shape_dtype_string((size, size), dtype)), - "size": size, "reorder": reorder, "dtype": dtype} - for size in [20, 50, 100] - for reorder in [0, 1, 2, 3] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex)) + @jtu.sample_product( + size=[20, 50, 100], + reorder=[0, 1, 2, 3], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) @unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/cusolver") @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU") @unittest.skipIf(xla_extension_version < 86, "test requires jaxlib version 86") @@ -2632,20 +2563,19 @@ class SparseSolverTest(jtu.JaxTestCase): return sparse.linalg.spsolve(data, indices, indptr, b, tol, reorder) x = sparse_solve(data, indices, indptr, b) - self.assertAllClose(a @ x, b, rtol=1e-2) + self.assertAllClose(a @ x, b, rtol=1e-2, atol=1e-3) self._CompileAndCheck(sparse_solve, args_maker) class SparseUtilTest(jtu.JaxTestCase): - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "dtype_{}_nbatch={}_ndense={}_nse={}".format( - dtype, n_batch, n_dense, expected_nse), - "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense, - "expected_nse": expected_nse} + @jtu.sample_product( + [dict(n_batch=n_batch, n_dense=n_dense, expected_nse=expected_nse) for n_batch, n_dense, expected_nse in - [(0, 0, 4), (1, 0, 2), (0, 1, 2), (2, 0, 1), (1, 1, 1), (0, 2, 1)] - for dtype in all_dtypes)) + [(0, 0, 4), (1, 0, 2), (0, 1, 2), (2, 0, 1), (1, 1, 1), (0, 2, 1)] + ], + dtype=all_dtypes, + ) def test_count_stored_elements(self, dtype, n_batch, n_dense, expected_nse): """Test counting nse.""" mat = np.array([[1, 0, 2, 0], [0, 0, 0, 0], [0, 3, 0, 4]], dtype=dtype) @@ -2653,15 +2583,14 @@ class SparseUtilTest(jtu.JaxTestCase): mat, n_batch=n_batch, n_dense=n_dense) self.assertEqual(expected_nse, actual_nse) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "dtype_{}_nbatch={}_ndense={}_nse={}".format( - dtype, n_batch, n_dense, expected_nse), - "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense, - "expected_nse": expected_nse} + @jtu.sample_product( + [dict(n_batch=n_batch, n_dense=n_dense, expected_nse=expected_nse) for n_batch, n_dense, expected_nse in - [(0, 0, 14), (1, 0, np.array([6, 8])), (0, 1, 9), - (2, 0, np.array([[3, 3], [4, 4]]))] - for dtype in all_dtypes)) + [(0, 0, 14), (1, 0, np.array([6, 8])), (0, 1, 9), + (2, 0, np.array([[3, 3], [4, 4]]))] + ], + dtype=all_dtypes + ) def test_count_stored_elements_per_batch(self, dtype, n_batch, n_dense, expected_nse): """Test counting nse.""" diff --git a/tests/sparsify_test.py b/tests/sparsify_test.py index 216372daf..150e76869 100644 --- a/tests/sparsify_test.py +++ b/tests/sparsify_test.py @@ -205,18 +205,15 @@ class SparsifyTest(jtu.JaxTestCase): self.assertAllClose(out.todense(), x.todense() + y.todense()) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_nbatch={}_ndense={}_unique_indices={}".format( - jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense, - unique_indices), - "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense, - "unique_indices": unique_indices} + @jtu.sample_product( + [dict(shape=shape, n_batch=n_batch, n_dense=n_dense) for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for dtype in (jtu.dtypes.integer + jtu.dtypes.floating + - jtu.dtypes.complex) for n_batch in range(len(shape) + 1) for n_dense in range(len(shape) + 1 - n_batch) - for unique_indices in [True, False])) + ], + dtype=jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex, + unique_indices=[True, False], + ) def testSparseMul(self, shape, dtype, n_batch, n_dense, unique_indices): rng_sparse = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero) x = BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch, @@ -281,10 +278,8 @@ class SparsifyTest(jtu.JaxTestCase): res_sparse = res_sparse.todense() self.assertArraysAllClose(res_dense, res_sparse) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_dimensions={}_nbatch={}_ndense={}".format( - jtu.format_shape_dtype_string(shape, np.float32), dimensions, n_batch, n_dense), - "shape": shape, "dimensions": dimensions, "n_batch": n_batch, "n_dense": n_dense} + @jtu.sample_product( + [dict(shape=shape, dimensions=dimensions, n_batch=n_batch, n_dense=n_dense) for shape, dimensions in [ [(1,), (0,)], [(1,), (-1,)], @@ -294,7 +289,9 @@ class SparsifyTest(jtu.JaxTestCase): [(2, 1, 3, 1), (3,)], ] for n_batch in range(len(shape) + 1) - for n_dense in range(len(shape) - n_batch + 1))) + for n_dense in range(len(shape) - n_batch + 1) + ], + ) def testSparseSqueeze(self, shape, dimensions, n_batch, n_dense): rng = jtu.rand_default(self.rng()) @@ -307,9 +304,8 @@ class SparsifyTest(jtu.JaxTestCase): self.assertAllClose(result_sparse, result_dense) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_shapes={shapes}_func={func}_nbatch={n_batch}", - "shapes": shapes, "func": func, "n_batch": n_batch} + @jtu.sample_product( + [dict(shapes=shapes, func=func, n_batch=n_batch) for shapes, func, n_batch in [ ([(4,), (4,)], "concatenate", 0), ([(4,), (4,)], "stack", 0), @@ -328,7 +324,9 @@ class SparsifyTest(jtu.JaxTestCase): ([(2, 4), (2, 5)], "hstack", 2), ([(2, 4), (4,), (3, 4)], "vstack", 0), ([(1, 4), (4,), (1, 4)], "vstack", 0), - ])) + ] + ], + ) def testSparseConcatenate(self, shapes, func, n_batch): f = self.sparsify(getattr(jnp, func)) rng = jtu.rand_some_zero(self.rng()) @@ -336,9 +334,8 @@ class SparsifyTest(jtu.JaxTestCase): sparrs = [BCOO.fromdense(arr, n_batch=n_batch) for arr in arrs] self.assertArraysEqual(f(arrs), f(sparrs).todense()) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_{shape}->{new_shape}_n_batch={n_batch}_n_dense={n_dense}", - "shape": shape, "new_shape": new_shape, "n_batch": n_batch, "n_dense": n_dense} + @jtu.sample_product( + [dict(shape=shape, new_shape=new_shape, n_batch=n_batch, n_dense=n_dense) for shape, new_shape, n_batch, n_dense in [ [(6,), (2, 3), 0, 0], [(1, 4), (2, 2), 0, 0], @@ -348,7 +345,9 @@ class SparsifyTest(jtu.JaxTestCase): [(2, 3, 4), (3, 8), 0, 0], [(2, 3, 4), (1, 2, 12), 1, 0], [(2, 3, 4), (6, 2, 2), 2, 0], - ])) + ] + ], + ) def testSparseReshapeMethod(self, shape, new_shape, n_batch, n_dense): rng = jtu.rand_some_zero(self.rng()) arr = rng(shape, 'int32') @@ -359,10 +358,9 @@ class SparsifyTest(jtu.JaxTestCase): self.assertArraysEqual(arr2, arr2_sparse.todense()) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_{shape}->{new_shape}_n_batch={n_batch}_n_dense={n_dense}_dimensions={dimensions}", - "shape": shape, "new_shape": new_shape, "n_batch": n_batch, "n_dense": n_dense, - "dimensions": dimensions} + @jtu.sample_product( + [dict(shape=shape, new_shape=new_shape, n_batch=n_batch, n_dense=n_dense, + dimensions=dimensions) for shape, new_shape, n_batch, n_dense, dimensions in [ [(2, 3, 4), (24,), 0, 0, None], [(2, 3, 4), (24,), 0, 0, (0, 1, 2)], @@ -375,7 +373,9 @@ class SparsifyTest(jtu.JaxTestCase): [(4, 2, 3), (2, 2, 6), 1, 0, (0, 2, 1)], [(2, 3, 4), (6, 4), 2, 0, (0, 1, 2)], [(2, 3, 4), (6, 4), 2, 0, (1, 0, 2)], - ])) + ] + ], + ) def testSparseReshapeWithDimensions(self, shape, new_shape, n_batch, n_dense, dimensions): rng = jtu.rand_some_zero(self.rng()) arr = rng(shape, 'int32') diff --git a/tests/stax_test.py b/tests/stax_test.py index e927a8699..eb05b62b7 100644 --- a/tests/stax_test.py +++ b/tests/stax_test.py @@ -15,7 +15,6 @@ """Tests for Stax library.""" from absl.testing import absltest -from absl.testing import parameterized import numpy as np @@ -50,105 +49,81 @@ def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape): @jtu.with_config(jax_numpy_rank_promotion="allow") class StaxTest(jtu.JaxTestCase): - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_shape={shape}", "shape": shape} - for shape in [(2, 3), (5,)])) + @jtu.sample_product(shape=[(2, 3), (5,)]) def testRandnInitShape(self, shape): key = random.PRNGKey(0) out = stax.randn()(key, shape) self.assertEqual(out.shape, shape) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_shape={shape}", "shape": shape} - for shape in [(2, 3), (2, 3, 4)])) + @jtu.sample_product(shape=[(2, 3), (2, 3, 4)]) def testGlorotInitShape(self, shape): key = random.PRNGKey(0) out = stax.glorot()(key, shape) self.assertEqual(out.shape, shape) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}" - .format(channels, filter_shape, padding, strides, input_shape), - "channels": channels, "filter_shape": filter_shape, "padding": padding, - "strides": strides, "input_shape": input_shape} - for channels in [2, 3] - for filter_shape in [(1, 1), (2, 3)] - for padding in ["SAME", "VALID"] - for strides in [None, (2, 1)] - for input_shape in [(2, 10, 11, 1)])) + @jtu.sample_product( + channels=[2, 3], + filter_shape=[(1, 1), (2, 3)], + padding=["SAME", "VALID"], + strides=[None, (2, 1)], + input_shape=[(2, 10, 11, 1)], + ) def testConvShape(self, channels, filter_shape, padding, strides, input_shape): init_fun, apply_fun = stax.Conv(channels, filter_shape, strides=strides, padding=padding) _CheckShapeAgreement(self, init_fun, apply_fun, input_shape) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}" - .format(channels, filter_shape, padding, strides, input_shape), - "channels": channels, "filter_shape": filter_shape, "padding": padding, - "strides": strides, "input_shape": input_shape} - for channels in [2, 3] - for filter_shape in [(1, 1), (2, 3), (3, 3)] - for padding in ["SAME", "VALID"] - for strides in [None, (2, 1), (2, 2)] - for input_shape in [(2, 10, 11, 1)])) + @jtu.sample_product( + channels=[2, 3], + filter_shape=[(1, 1), (2, 3), (3, 3)], + padding=["SAME", "VALID"], + strides=[None, (2, 1), (2, 2)], + input_shape=[(2, 10, 11, 1)], + ) def testConvTransposeShape(self, channels, filter_shape, padding, strides, input_shape): init_fun, apply_fun = stax.ConvTranspose(channels, filter_shape, # 2D strides=strides, padding=padding) _CheckShapeAgreement(self, init_fun, apply_fun, input_shape) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}" - .format(channels, filter_shape, padding, strides, input_shape), - "channels": channels, "filter_shape": filter_shape, "padding": padding, - "strides": strides, "input_shape": input_shape} - for channels in [2, 3] - for filter_shape in [(1,), (2,), (3,)] - for padding in ["SAME", "VALID"] - for strides in [None, (1,), (2,)] - for input_shape in [(2, 10, 1)])) + + @jtu.sample_product( + channels=[2, 3], + filter_shape=[(1,), (2,), (3,)], + padding=["SAME", "VALID"], + strides=[None, (1,), (2,)], + input_shape=[(2, 10, 1)], + ) def testConv1DTransposeShape(self, channels, filter_shape, padding, strides, input_shape): init_fun, apply_fun = stax.Conv1DTranspose(channels, filter_shape, strides=strides, padding=padding) _CheckShapeAgreement(self, init_fun, apply_fun, input_shape) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_out_dim={}_input_shape={}" - .format(out_dim, input_shape), - "out_dim": out_dim, "input_shape": input_shape} - for out_dim in [3, 4] - for input_shape in [(2, 3), (3, 4)])) + @jtu.sample_product( + out_dim=[3, 4], + input_shape=[(2, 3), (3, 4)], + ) def testDenseShape(self, out_dim, input_shape): init_fun, apply_fun = stax.Dense(out_dim) _CheckShapeAgreement(self, init_fun, apply_fun, input_shape) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_input_shape={}_nonlinear={}" - .format(input_shape, nonlinear), - "input_shape": input_shape, "nonlinear": nonlinear} - for input_shape in [(2, 3), (2, 3, 4)] - for nonlinear in ["Relu", "Sigmoid", "Elu", "LeakyRelu"])) + @jtu.sample_product( + input_shape=[(2, 3), (2, 3, 4)], + nonlinear=["Relu", "Sigmoid", "Elu", "LeakyRelu"], + ) def testNonlinearShape(self, input_shape, nonlinear): init_fun, apply_fun = getattr(stax, nonlinear) _CheckShapeAgreement(self, init_fun, apply_fun, input_shape) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_window_shape={}_padding={}_strides={}_input_shape={}" - "_maxpool={}_spec={}" - .format(window_shape, padding, strides, input_shape, - max_pool, spec), - "window_shape": window_shape, "padding": padding, "strides": strides, - "input_shape": input_shape, "max_pool": max_pool, "spec": spec} - for window_shape in [(1, 1), (2, 3)] - for padding in ["VALID"] - for strides in [None, (2, 1)] - for input_shape in [(2, 5, 6, 4)] - for max_pool in [False, True] - for spec in ["NHWC", "NCHW", "WHNC", "WHCN"])) + @jtu.sample_product( + window_shape=[(1, 1), (2, 3)], + padding=["VALID"], + strides=[None, (2, 1)], + input_shape=[(2, 5, 6, 4)], + max_pool=[False, True], + spec=["NHWC", "NCHW", "WHNC", "WHCN"], + ) def testPoolingShape(self, window_shape, padding, strides, input_shape, max_pool, spec): layer = stax.MaxPool if max_pool else stax.AvgPool @@ -156,49 +131,41 @@ class StaxTest(jtu.JaxTestCase): spec=spec) _CheckShapeAgreement(self, init_fun, apply_fun, input_shape) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_shape={input_shape}", - "input_shape": input_shape} - for input_shape in [(2, 3), (2, 3, 4)])) + @jtu.sample_product(input_shape=[(2, 3), (2, 3, 4)]) def testFlattenShape(self, input_shape): init_fun, apply_fun = stax.Flatten _CheckShapeAgreement(self, init_fun, apply_fun, input_shape) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_input_shape={input_shape}_spec={i}", - "input_shape": input_shape, "spec": spec} - for input_shape in [(2, 5, 6, 1)] - for i, spec in enumerate([ - [stax.Conv(3, (2, 2))], - [stax.Conv(3, (2, 2)), stax.Flatten, stax.Dense(4)]]))) + @jtu.sample_product( + input_shape=[(2, 5, 6, 1)], + spec=[ + [stax.Conv(3, (2, 2))], + [stax.Conv(3, (2, 2)), stax.Flatten, stax.Dense(4)], + ], + ) def testSerialComposeLayersShape(self, input_shape, spec): init_fun, apply_fun = stax.serial(*spec) _CheckShapeAgreement(self, init_fun, apply_fun, input_shape) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_input_shape={input_shape}", - "input_shape": input_shape} - for input_shape in [(3, 4), (2, 5, 6, 1)])) + @jtu.sample_product(input_shape=[(3, 4), (2, 5, 6, 1)]) def testDropoutShape(self, input_shape): init_fun, apply_fun = stax.Dropout(0.9) _CheckShapeAgreement(self, init_fun, apply_fun, input_shape) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_input_shape={input_shape}", - "input_shape": input_shape} - for input_shape in [(3, 4), (2, 5, 6, 1)])) + @jtu.sample_product(input_shape=[(3, 4), (2, 5, 6, 1)]) def testFanInSum(self, input_shape): init_fun, apply_fun = stax.FanInSum _CheckShapeAgreement(self, init_fun, apply_fun, [input_shape, input_shape]) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_inshapes={input_shapes}_axis={axis}", - "input_shapes": input_shapes, "axis": axis} + @jtu.sample_product( + [dict(input_shapes=input_shapes, axis=axis) for input_shapes, axis in [ ([(2, 3), (2, 1)], 1), ([(2, 3), (2, 1)], -1), ([(1, 2, 4), (1, 1, 4)], 1), - ])) + ] + ], + ) def testFanInConcat(self, input_shapes, axis): init_fun, apply_fun = stax.FanInConcat(axis) _CheckShapeAgreement(self, init_fun, apply_fun, input_shapes) diff --git a/tests/xmap_test.py b/tests/xmap_test.py index af9836cf1..bffa645f3 100644 --- a/tests/xmap_test.py +++ b/tests/xmap_test.py @@ -46,7 +46,7 @@ from jax._src.nn import initializers as nn_initializers from jax._src.lib import xla_bridge from jax._src.lib import xla_client from jax._src.lib import xla_extension_version -from jax._src.util import curry, unzip2, prod, safe_zip +from jax._src.util import unzip2, prod, safe_zip from jax._src.lax import parallel as lax_parallel from jax._src.lax.parallel import pgather from jax.interpreters import batching, pxla @@ -876,13 +876,12 @@ class XMapTestManualSPMD(ManualSPMDTestMixin, XMapTestCase): class NamedNumPyTest(XMapTestCase): - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_{reduction.__name__}_axes={axes}_i={mapped_axis}", - "reduction": reduction, "axes": axes, "mapped_axis": mapped_axis} - for reduction in (jnp.sum, jnp.max, jnp.min, jnp.mean, jnp.var, jnp.std, - jscipy.special.logsumexp) - for axes in (0, 'i', (1,), ('i',), (0, 1), (0, 'i'), ('i', 0)) - for mapped_axis in range(3))) + @jtu.sample_product( + reduction=(jnp.sum, jnp.max, jnp.min, jnp.mean, jnp.var, jnp.std, + jscipy.special.logsumexp), + axes=(0, 'i', (1,), ('i',), (0, 1), (0, 'i'), ('i', 0)), + mapped_axis=range(3), + ) def testReductions(self, reduction, axes, mapped_axis): axes_t = axes if isinstance(axes, tuple) else (axes,) ref_red = partial(reduction, @@ -901,25 +900,15 @@ class NamedNumPyTest(XMapTestCase): class NamedRandomTest(XMapTestCase): - @curry - def parameterize_by_sampler(extra, f, subset): - if extra is None: - extra = [("", {})] - else: - extra = list(extra) - subset_fn = jtu.cases_from_list if subset else lambda x: x - return parameterized.named_parameters(subset_fn( - {"testcase_name": name + extra_name, "distr_sample": sample, **extra_kwargs} - for name, sample in [ - ("Uniform", jax.random.uniform), - ("Normal", jax.random.normal), - ("Bernoulli", partial(jax.random.bernoulli, p=0.5)), - ("TruncatedNormal", partial(jax.random.truncated_normal, lower=-2, upper=2)), - ] - for extra_name, extra_kwargs in extra))(f) + SAMPLERS = [ + ("Uniform", jax.random.uniform), + ("Normal", jax.random.normal), + ("Bernoulli", partial(jax.random.bernoulli, p=0.5)), + ("TruncatedNormal", partial(jax.random.truncated_normal, lower=-2, upper=2)), + ] - @parameterize_by_sampler(None, subset=False) - def testSamplerSharding(self, distr_sample): + @parameterized.parameters(*SAMPLERS) + def testSamplerSharding(self, distr_name, distr_sample): def sample(shape, map_size): return xmap(lambda: distr_sample(jax.random.PRNGKey(0), shape=shape), in_axes=(), out_axes=[None, 'i', ...], axis_sizes={'i': map_size})() @@ -931,12 +920,15 @@ class NamedRandomTest(XMapTestCase): with self.assertRaisesRegex(ValueError, error): sample(NamedShape(3, i=4), 5) - @parameterize_by_sampler( - ((f"_mesh={mesh}_resources={sorted(axis_resources.items())}", - {"axis_resources": tuple(axis_resources.items()), "mesh": tuple(mesh)}) - for axis_resources, mesh in schedules({'i': 4, 'j': 6})), subset=True) + @jtu.sample_product( + [dict(distr_name=name, distr_sample=sample) + for name, sample in SAMPLERS], + [dict(axis_resources=tuple(axis_resources.items()), mesh=tuple(mesh)) + for axis_resources, mesh in schedules({'i': 4, 'j': 6})], + ) @jtu.with_mesh_from_kwargs - def testSamplerResourceIndependence(self, distr_sample, axis_resources, mesh): + def testSamplerResourceIndependence(self, distr_name, distr_sample, + axis_resources, mesh): def sample(axis_resources): return xmap(lambda: distr_sample(jax.random.PRNGKey(0), shape=NamedShape(3, i=4, j=6)), in_axes=(), out_axes=['i', 'j', ...], axis_sizes={'i': 4, 'j': 6}, @@ -965,13 +957,12 @@ class NamedNNTest(XMapTestCase): with self.assertRaisesRegex(ValueError, "to match the size of axis i, but 3 != 5"): f(jnp.ones(5, dtype='int32')) - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_map_in={map_in}_map_out={map_out}_fan={fan}_distr={distr}", - "map_in": map_in, "map_out": map_out, "fan": fan, - "distr": distr} - for map_in, map_out in [(True, False), (False, True), (True, True)] - for fan in ['fan_in', 'fan_out', 'fan_avg'] - for distr in ['uniform', 'normal', 'truncated_normal'])) + @jtu.sample_product( + [dict(map_in=map_in, map_out=map_out) + for map_in, map_out in [(True, False), (False, True), (True, True)]], + fan=['fan_in', 'fan_out', 'fan_avg'], + distr=['uniform', 'normal', 'truncated_normal'], + ) def testVarianceScaling(self, map_in, map_out, fan, distr): shape = (80, 50, 7) fan_in, fan_out = nn_initializers._compute_fans(NamedShape(*shape), 0, 1)