Migrate remaining tests from jtu.cases_from_list to jtu.sample_product.

Delete jtu.cases_from_list.
This commit is contained in:
Peter Hawkins 2022-10-12 13:51:11 +00:00
parent 9bb2c999d6
commit 72f4f389be
18 changed files with 1331 additions and 1835 deletions

View File

@ -658,20 +658,6 @@ def assert_dot_precision(expected_precision, fun, *args):
assert precision == expected_precision, msg
_CACHED_INDICES: Dict[int, Sequence[int]] = {}
def cases_from_list(xs):
xs = list(xs)
n = len(xs)
k = min(n, FLAGS.num_generated_cases)
# Random sampling for every parameterized test is expensive. Do it once and
# cache the result.
indices = _CACHED_INDICES.get(n)
if indices is None:
rng = npr.RandomState(42)
_CACHED_INDICES[n] = indices = rng.permutation(n)
return [xs[i] for i in indices[:k]]
def cases_from_gens(*gens):
sizes = [1, 3, 10]
cases_per_size = int(FLAGS.num_generated_cases / len(sizes)) + 1
@ -703,14 +689,20 @@ def named_cases_from_sampler(gen):
yield case
# Random sampling for every parameterized test is expensive. Do it once and
# cache the result.
@functools.lru_cache(maxsize=None)
def _choice(n, m):
rng = np.random.RandomState(42)
return rng.choice(n, size=m, replace=False)
def sample_product_testcases(*args, **kw):
"""Non-decorator form of sample_product."""
args = [list(arg) for arg in args]
kw = [(k, list(v)) for k, v in kw.items()]
n = prod(len(a) for a in args) * prod(len(v) for _, v in kw)
rng = np.random.RandomState(42)
testcases = []
for i in rng.choice(n, size=min(n, FLAGS.num_generated_cases), replace=False):
for i in _choice(n, min(n, FLAGS.num_generated_cases)):
testcase = {}
for a in args:
testcase.update(a[i % len(a)])

View File

@ -21,7 +21,6 @@ import unittest
from absl import logging
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import ad_checkpoint
@ -200,12 +199,10 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
self.assertIsNotNone(f_tf.get_concrete_function())
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_dtype={dtype.__name__}_function={with_function}",
dtype=dtype,
with_function=with_function)
for dtype in [np.int64, np.float64]
for with_function in [True, False]))
@jtu.sample_product(
dtype=[np.int64, np.float64],
with_function=[True, False],
)
def test_converts_64bit(self, dtype=np.int64, with_function=False):
if not config.jax_enable_x64:
self.skipTest("requires x64 mode")
@ -251,10 +248,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))
self.ConvertAndCompare(f_jax, 0.7)
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
@jtu.sample_product(with_function=[False, True])
def test_gradients_disabled(self, with_function=False):
f_tf = jax2tf.convert(jnp.tan, with_gradient=False)
if with_function:
@ -270,10 +264,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
y = f_tf(x)
_ = tape.gradient(y, x)
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
@jtu.sample_product(with_function=[False, True])
def test_gradients(self, with_function=True):
def f(x, y):
return x * x, x * y
@ -291,10 +282,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(5., tape.gradient(v, x))
self.assertAllClose(4., tape.gradient(v, y))
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
@jtu.sample_product(with_function=[False, True])
def test_gradients_pytree(self, with_function=True):
def f(xy: Tuple[float, float]) -> Dict[str, float]:
x, y = xy
@ -368,10 +356,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(grad_jax.b, grad_tf[1])
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
@jtu.sample_product(with_function=[False, True])
def test_gradients_with_ordered_dict_input(self, with_function=True):
def f(inputs):
out = 0.0
@ -394,10 +379,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(np.array([1.]), tape.gradient(u, x).numpy())
self.assertAllClose(np.array([1., 1.]), tape.gradient(u, y).numpy())
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
@jtu.sample_product(with_function=[False, True])
def test_gradients_with_custom_jvp(self, with_function=True):
"""Check gradients, for a function with custom JVP."""
@jax.custom_jvp
@ -428,10 +410,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(4. * 4., y)
self.assertAllClose(3. * 4., tape.gradient(y, x))
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
@jtu.sample_product(with_function=[False, True])
def test_gradients_with_custom_vjp(self, with_function=True):
"""Check gradients, for a function with custom VJP."""
@jax.custom_vjp
@ -498,10 +477,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(jnp.zeros(np.shape(d_dx_jax), np.int32),
d_dx_tf.numpy())
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
@jtu.sample_product(with_function=[False, True])
def test_gradients_unused_argument_readme(self, with_function=False):
# x1 and x3 are not used. x3 has integer type.
def fn(x0, x1, x2, x3):
@ -546,10 +522,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(g_jax2tf[2].numpy(), np.float32(2.))
self.assertAllClose(g_jax2tf[3].numpy(), np.int32(0))
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
@jtu.sample_product(with_function=[False, True])
def test_gradients_int_argument(self, with_function=False):
# https://github.com/google/jax/issues/6975
# Also issue #6975.
@ -875,9 +848,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
jax2tf.convert(outer)(2.)
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_{transform}", transform=transform)
for transform in ["jit", "jvp", "grad", "vmap"]))
@jtu.sample_product(transform=["jit", "jvp", "grad", "vmap"])
def test_convert_under_transform_error(self, transform="vmap"):
def outer(y):
return jax2tf.convert(jnp.sin)(y) # Inner convert takes tracer args
@ -886,9 +857,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
ValueError, "convert must be used outside all JAX transformations"):
self.TransformConvertAndCompare(outer, np.ones((4,)), transform)
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_{transform}", transform=transform)
for transform in ["jit", "jvp", "grad", "vmap"]))
@jtu.sample_product(transform=["jit", "jvp", "grad", "vmap"])
def test_convert_under_transform_error_non_tracer(self, transform="vmap"):
def outer(y):
sin_1 = jax2tf.convert(jnp.sin)(1.) # Inner convert takes non-tracer arg
@ -997,10 +966,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(tf_fn(tf.constant(1.375, tf.bfloat16)).numpy(),
jnp.bfloat16(2.750))
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
@jtu.sample_product(with_function=[False, True])
def test_kwargs(self, with_function=True):
# Re: https://github.com/google/jax/issues/6791
def f_jax(*, x):
@ -1012,10 +978,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
f_tf(x=np.zeros(3, dtype=np.float32)), # Call with kwargs.
np.zeros((), dtype=np.float32))
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
@jtu.sample_product(with_function=[False, True])
def test_grad_kwargs(self, with_function=False):
# Re: https://github.com/google/jax/issues/6791
x = (np.zeros(3, dtype=np.float32),

View File

@ -294,29 +294,19 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
f_jax = jax.jit(lambda i: params[i])
self.ConvertAndCompare(f_jax, indices)
@parameterized.named_parameters(
jtu.cases_from_list(
dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
for f_jax in REDUCE))
@jtu.sample_product(f_jax=REDUCE)
def test_reduce_ops_with_numerical_input(self, f_jax):
values = np.array([1, 2, 3], dtype=np.float32)
self.ConvertAndCompare(f_jax, values)
@parameterized.named_parameters(
jtu.cases_from_list(
dict(testcase_name=f"_{op}", op=op) for op in (
"add", "max", "min", "multiply", "set"
)))
@jtu.sample_product(op=["add", "max", "min", "multiply", "set"])
def test_scatter_static(self, op):
values = np.ones((5, 6), dtype=np.float32)
update = np.float32(6.)
f_jax = jax.jit(lambda v, u: getattr(v.at[::2, 3:], op)(u))
self.ConvertAndCompare(f_jax, values, update)
@parameterized.named_parameters(
jtu.cases_from_list(
dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
for f_jax in REDUCE))
@jtu.sample_product(f_jax=REDUCE)
def test_reduce_ops_with_boolean_input(self, f_jax):
values = np.array([True, False, True], dtype=np.bool_)
self.ConvertAndCompare(f_jax, values)

View File

@ -834,10 +834,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
res_jax,
jax2tf.convert(f, polymorphic_shapes=["(b, h)", "h"])(x, y))
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
@jtu.sample_product(with_function=[False, True])
def test_grad_int(self, with_function=True):
# https://github.com/google/jax/issues/7093
# Also issue #6975.

View File

@ -14,7 +14,7 @@
import unittest
from absl.testing import absltest, parameterized
from absl.testing import absltest
import jax
from jax.config import config
@ -67,16 +67,12 @@ class DLPackTest(jtu.JaxTestCase):
if jtu.device_under_test() == "tpu":
self.skipTest("DLPack not supported on TPU")
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_take_ownership={}_gpu={}".format(
jtu.format_shape_dtype_string(shape, dtype),
take_ownership, gpu),
"shape": shape, "dtype": dtype, "take_ownership": take_ownership,
"gpu": gpu}
for shape in all_shapes
for dtype in dlpack_dtypes
for take_ownership in [False, True]
for gpu in [False, True]))
@jtu.sample_product(
shape=all_shapes,
dtype=dlpack_dtypes,
take_ownership=[False, True],
gpu=[False, True],
)
@jtu.skip_on_devices("rocm") # TODO(sharadmv,phawkins): see GH issue #10973
def testJaxRoundTrip(self, shape, dtype, take_ownership, gpu):
rng = jtu.rand_default(self.rng())
@ -95,12 +91,10 @@ class DLPackTest(jtu.JaxTestCase):
"DLPack tensor may be consumed at most once",
lambda: jax.dlpack.from_dlpack(dlpack))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(
jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in all_shapes
for dtype in dlpack_dtypes))
@jtu.sample_product(
shape=all_shapes,
dtype=dlpack_dtypes,
)
@unittest.skipIf(not tf, "Test requires TensorFlow")
@jtu.skip_on_devices("rocm") # TODO(sharadmv,phawkins): see GH issue #10973
def testTensorFlowToJax(self, shape, dtype):
@ -121,12 +115,10 @@ class DLPackTest(jtu.JaxTestCase):
y = jax.dlpack.from_dlpack(dlpack)
self.assertAllClose(np, y)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(
jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in all_shapes
for dtype in dlpack_dtypes))
@jtu.sample_product(
shape=all_shapes,
dtype=dlpack_dtypes,
)
@unittest.skipIf(not tf, "Test requires TensorFlow")
def testJaxToTensorFlow(self, shape, dtype):
if not config.x64_enabled and dtype in [jnp.int64, jnp.uint64,
@ -145,12 +137,10 @@ class DLPackTest(jtu.JaxTestCase):
y = tf.experimental.dlpack.from_dlpack(dlpack)
self.assertAllClose(np, y.numpy())
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(
jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in all_shapes
for dtype in torch_dtypes))
@jtu.sample_product(
shape=all_shapes,
dtype=torch_dtypes,
)
@unittest.skipIf(not torch, "Test requires PyTorch")
def testTorchToJax(self, shape, dtype):
if not config.x64_enabled and dtype in [jnp.int64, jnp.float64]:
@ -177,12 +167,10 @@ class DLPackTest(jtu.JaxTestCase):
xla_client._xla.dlpack_managed_tensor_to_buffer(
y, client)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(
jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in all_shapes
for dtype in torch_dtypes))
@jtu.sample_product(
shape=all_shapes,
dtype=torch_dtypes,
)
@unittest.skipIf(not torch, "Test requires PyTorch")
def testJaxToTorch(self, shape, dtype):
if not config.x64_enabled and dtype in [jnp.int64, jnp.float64]:
@ -194,12 +182,10 @@ class DLPackTest(jtu.JaxTestCase):
y = torch.utils.dlpack.from_dlpack(dlpack)
self.assertAllClose(np, y.cpu().numpy())
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(
jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in all_shapes
for dtype in torch_dtypes))
@jtu.sample_product(
shape=all_shapes,
dtype=torch_dtypes,
)
@unittest.skipIf(numpy_version < (1, 22, 0), "Requires numpy 1.22 or newer")
@jtu.skip_on_devices("rocm") # TODO(sharadmv,phawkins): see GH issue #10973
def testNumpyToJax(self, shape, dtype):
@ -208,12 +194,10 @@ class DLPackTest(jtu.JaxTestCase):
x_jax = jnp.from_dlpack(x_np)
self.assertAllClose(x_np, x_jax)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(
jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in all_shapes
for dtype in torch_dtypes))
@jtu.sample_product(
shape=all_shapes,
dtype=torch_dtypes,
)
@unittest.skipIf(numpy_version < (1, 22, 0), "Requires numpy 1.22 or newer")
@jtu.skip_on_devices("gpu")
def testJaxToNumpy(self, shape, dtype):
@ -230,12 +214,10 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):
if jtu.device_under_test() != "gpu":
self.skipTest("__cuda_array_interface__ is only supported on GPU")
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(
jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in all_shapes
for dtype in dlpack_dtypes))
@jtu.sample_product(
shape=all_shapes,
dtype=dlpack_dtypes,
)
@unittest.skipIf(not cupy, "Test requires CuPy")
def testJaxToCuPy(self, shape, dtype):
if dtype == jnp.bfloat16:

View File

@ -24,7 +24,6 @@ import unittest
from unittest import skip, SkipTest
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import ad_checkpoint
@ -514,12 +513,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
self.assertEqual(
len(local_devices()), len(re.findall(r"112", testing_stream.output)))
@parameterized.named_parameters(
jtu.cases_from_list(
dict(
testcase_name=f"_with_jit_{with_jit}",
with_jit=with_jit)
for with_jit in [True, False]))
@jtu.sample_product(with_jit=[True, False])
def test_tap_pytree(self, with_jit=False):
def func(x, what=""):
"""Returns some pytrees depending on x"""
@ -551,12 +545,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
hcb.barrier_wait() # Wait for receivers to be done
self.assertEqual(3, tap_count)
@parameterized.named_parameters(
jtu.cases_from_list(
dict(
testcase_name=f"_concurrent_{concurrent}",
concurrent=concurrent)
for concurrent in [True, False]))
@jtu.sample_product(concurrent=[True, False])
def test_tap_multiple(self, concurrent=False):
"""Call id_tap multiple times, concurrently or in sequence. """
if concurrent and jtu.device_under_test() in ["cpu", "gpu"]:
@ -628,12 +617,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
[t.start() for t in threads]
[t.join() for t in threads]
@parameterized.named_parameters(
jtu.cases_from_list(
dict(
testcase_name=f"_with_jit_{with_jit}",
with_jit=with_jit)
for with_jit in [True, False]))
@jtu.sample_product(with_jit=[True, False])
def test_tap_cond(self, with_jit=False):
"""A conditional"""
@ -663,11 +647,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
where: end
4""", testing_stream.output)
@parameterized.named_parameters(
jtu.cases_from_list(
dict(testcase_name=f"_with_jit_{with_jit}",
with_jit=with_jit)
for with_jit in [True, False]))
@jtu.sample_product(with_jit=[True, False])
def test_tap_while_cond(self, with_jit=False):
def func(x):
x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
@ -741,12 +721,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
where: 3
3""", testing_stream.output)
@parameterized.named_parameters(
jtu.cases_from_list(
dict(
testcase_name=f"_with_jit_{with_jit}",
with_jit=with_jit)
for with_jit in [True, False]))
@jtu.sample_product(with_jit=[True, False])
def test_tap_scan_cond(self, with_jit=True):
def func(x):
x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
@ -796,15 +771,11 @@ class HostCallbackTapTest(jtu.JaxTestCase):
[1 2 3]""", testing_stream.output)
testing_stream.reset()
@parameterized.named_parameters(
jtu.cases_from_list(
dict(
testcase_name=f"_shape_{shape}_dtype_{np.dtype(dtype).name}_nr_args={nr_args}",
shape=shape,
dtype=dtype,
nr_args=nr_args) for nr_args in [1, 2]
for shape in [(), (2,), (2, 3), (2, 3, 4)]
for dtype in jtu.dtypes.all))
@jtu.sample_product(
nr_args=[1, 2],
shape=[(), (2,), (2, 3), (2, 3, 4)],
dtype=jtu.dtypes.all,
)
def test_tap_jit_dtypes(self, nr_args=2, dtype=jnp.int16, shape=(2,)):
if dtype in (jnp.complex64, jnp.complex128, jnp.bool_):
raise SkipTest(f"host_callback not implemented for {dtype}.")
@ -1971,13 +1942,11 @@ class HostCallbackTapTest(jtu.JaxTestCase):
10"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
@parameterized.named_parameters(
jtu.cases_from_list(
dict(testcase_name=f"_use_remat={use_remat}_{grad_func}_use_result={use_result}",
use_result=use_result, use_remat=use_remat, grad_func=grad_func)
for use_result in [True, False]
for grad_func in ["grad", "value_and_grad"]
for use_remat in ["old", "new", "none"]))
@jtu.sample_product(
use_result=[True, False],
grad_func=["grad", "value_and_grad"],
use_remat=["old", "new", "none"],
)
def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="new"):
if use_remat == "old": raise SkipTest()
@ -2108,11 +2077,9 @@ class HostCallbackCallTest(jtu.JaxTestCase):
self.assertAllClose(2 * arg, fun(arg))
self.assertEqual(count[0], 1)
@parameterized.named_parameters(
jtu.cases_from_list(
dict(testcase_name=f"_{np.dtype(dtype).name}", dtype=dtype)
for dtype in jtu.dtypes.all
if dtype != np.bool_))
@jtu.sample_product(
dtype=[dtype for dtype in jtu.dtypes.all if dtype != np.bool_],
)
def test_call_types(self, dtype=np.float64):
def f_outside(x):

View File

@ -133,7 +133,7 @@ LAX_GRAD_OPS = [
grad_test_spec(lax.rsqrt, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_complex_dtypes),
grad_test_spec(lax.cbrt, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_float_dtypes, tol={np.float64: 3e-5}),
dtypes=grad_float_dtypes, tol={np.float64: 5e-3}),
grad_test_spec(lax.logistic, nargs=1, order=2,
rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes),
@ -195,38 +195,43 @@ def check_grads_bilinear(f, args, order,
class LaxAutodiffTest(jtu.JaxTestCase):
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(
rec.name, shapes, itertools.repeat(dtype)),
"op": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes, "dtype": dtype,
"order": rec.order, "tol": rec.tol}
for shape_group in compatible_shapes
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(op=rec.op, rng_factory=rec.rng_factory, order=rec.order, tol=rec.tol)],
shapes=[
shapes for shape_group in compatible_shapes
for shapes in itertools.combinations_with_replacement(shape_group, rec.nargs)
for dtype in rec.dtypes)
for rec in LAX_GRAD_OPS))
],
dtype=rec.dtypes,
)
for rec in LAX_GRAD_OPS
))
def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol):
rng = rng_factory(self.rng())
if jtu.device_under_test() == "tpu" and op is lax.pow:
raise SkipTest("pow grad imprecise on tpu")
if jtu.device_under_test() == "tpu":
if op is lax.pow:
raise SkipTest("pow grad imprecise on tpu")
if op is lax.cos:
order = 1 # 2nd-order gradient is imprecise on TPU.
tol = jtu.join_tolerance(1e-1, tol) if jtu.num_float_bits(dtype) == 32 else tol
args = tuple(rng(shape, dtype) for shape in shapes)
check_grads(op, args, order, ["fwd", "rev"], tol, tol)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": f"_{rec.op.__name__}_{special_value}",
"op": rec.op, "special_value": special_value, "tol": rec.tol}
for special_value in rec.values)
for rec in LAX_GRAD_SPECIAL_VALUE_TESTS))
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(op=rec.op, tol=rec.tol)],
special_value=rec.values,
)
for rec in LAX_GRAD_SPECIAL_VALUE_TESTS
))
def testOpGradSpecialValue(self, op, special_value, tol):
check_grads(op, (special_value,), 2, ["fwd", "rev"], rtol=tol, atol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_from_dtype={}_to_dtype={}".format(
jtu.dtype_str(from_dtype), jtu.dtype_str(to_dtype)),
"from_dtype": from_dtype, "to_dtype": to_dtype}
for from_dtype, to_dtype in itertools.product(inexact_dtypes, repeat=2)))
@jtu.sample_product(
from_dtype=inexact_dtypes,
to_dtype=inexact_dtypes,
)
def testConvertElementTypeGrad(self, from_dtype, to_dtype):
rng = jtu.rand_default(self.rng())
tol = max(jtu.tolerance(to_dtype, jtu.default_gradient_tolerance),
@ -237,12 +242,10 @@ class LaxAutodiffTest(jtu.JaxTestCase):
convert_element_type)
check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(
jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in [(), (2, 3)]
for dtype in grad_float_dtypes))
@jtu.sample_product(
shape=[(), (2, 3)],
dtype=grad_float_dtypes,
)
def testClampGrad(self, shape, dtype):
rng = jtu.rand_default(self.rng())
operand = rng(shape, dtype)
@ -253,15 +256,14 @@ class LaxAutodiffTest(jtu.JaxTestCase):
check_grads(lax.clamp, (low, operand, high), 2, ["fwd", "rev"], eps=1e-2)
check_grads(lax.clamp, (low, high, operand), 2, ["fwd", "rev"], eps=1e-2)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format(
dim, ",".join(str(d) for d in base_shape), np.dtype(dtype).name,
num_arrs),
"dim": dim, "base_shape": base_shape, "dtype": dtype, "num_arrs": num_arrs}
for num_arrs in [3]
for dtype in float_dtypes
@jtu.sample_product(
[dict(base_shape=base_shape, dim=dim)
for base_shape in [(4,), (3, 4), (2, 3, 4)]
for dim in range(len(base_shape))))
for dim in range(len(base_shape))
],
num_arrs=[3],
dtype=float_dtypes,
)
def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs):
rng = jtu.rand_default(self.rng())
shapes = [base_shape[:dim] + (size,) + base_shape[dim+1:]
@ -270,21 +272,17 @@ class LaxAutodiffTest(jtu.JaxTestCase):
concatenate = lambda *args: lax.concatenate(args, dim)
check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
.format(jtu.format_shape_dtype_string(lhs_shape, dtype),
jtu.format_shape_dtype_string(rhs_shape, dtype),
strides, padding),
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
"strides": strides, "padding": padding}
@jtu.sample_product(
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, strides=strides)
for lhs_shape, rhs_shape, all_strides in itertools.chain(
[((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)])
for b, i, j in itertools.product([2, 3], repeat=3)],
[((4, 2, 1), (3, 2, 1), [(1,)])])
for strides in all_strides
for dtype in float_dtypes
for padding in ["VALID", "SAME"]))
],
dtype=float_dtypes,
padding=["VALID", "SAME"],
)
def testConvGrad(self, lhs_shape, rhs_shape, dtype, strides, padding):
rng = jtu.rand_small(self.rng())
lhs = rng(lhs_shape, dtype)
@ -294,18 +292,11 @@ class LaxAutodiffTest(jtu.JaxTestCase):
check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
atol=1e-2, rtol=1e-2)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
"rhs_dilation={}"
.format(jtu.format_shape_dtype_string(lhs_shape, dtype),
jtu.format_shape_dtype_string(rhs_shape, dtype),
strides, padding, lhs_dil, rhs_dil),
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
"strides": strides, "padding": padding, "lhs_dil": lhs_dil,
"rhs_dil": rhs_dil}
@jtu.sample_product(
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, strides=strides,
padding=padding, lhs_dil=lhs_dil, rhs_dil=rhs_dil)
for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in
itertools.chain(
itertools.chain(
[((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)],
[((0, 0), (0, 0)), ((-1, 0), (0, -1)), ((1, 0), (0, 1))],
[(1, 1), (2, 1)], [(1, 1)])
@ -315,8 +306,10 @@ class LaxAutodiffTest(jtu.JaxTestCase):
for strides in all_strides
for rhs_dil in rhs_dils
for lhs_dil in lhs_dils
for dtype in float_dtypes
for padding in all_pads))
for padding in all_pads
],
dtype=float_dtypes,
)
def testConvWithGeneralPaddingGrad(self, lhs_shape, rhs_shape, dtype, strides,
padding, lhs_dil, rhs_dil):
rng = jtu.rand_small(self.rng())
@ -328,19 +321,11 @@ class LaxAutodiffTest(jtu.JaxTestCase):
check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
atol=1e-2, rtol=1e-2)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
"rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}"
.format(jtu.format_shape_dtype_string(lhs_shape, dtype),
jtu.format_shape_dtype_string(rhs_shape, dtype),
strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums),
feature_group_count, batch_group_count),
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
"strides": strides, "padding": padding, "lhs_dil": lhs_dil,
"rhs_dil": rhs_dil, "dimension_numbers": dim_nums,
"perms": perms, "feature_group_count": feature_group_count,
"batch_group_count": batch_group_count}
@jtu.sample_product(
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, strides=strides,
padding=padding, lhs_dil=lhs_dil, rhs_dil=rhs_dil,
feature_group_count=feature_group_count,
batch_group_count=batch_group_count)
for batch_group_count, feature_group_count in ([(1, 1), (2, 1), (1, 2)])
for lhs_shapes, rhs_shape, all_strides, lhs_dils, rhs_dils in [
([(b * batch_group_count, i * feature_group_count, 6, 7),
@ -354,13 +339,17 @@ class LaxAutodiffTest(jtu.JaxTestCase):
for strides in all_strides
for rhs_dil in rhs_dils
for lhs_dil in lhs_dils
for dtype in grad_inexact_dtypes
for padding in ([((0, 0), (0, 0)), ((1, 0), (0, 1))] +
([((0, -1), (0, 0))] if lhs_shape[2] != 0 else []))
([((0, -1), (0, 0))] if lhs_shape[2] != 0 else []))
],
[dict(dimension_numbers=dim_nums, perms=perms)
for dim_nums, perms in [
(("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
(("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
(("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))]))
(("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
(("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
(("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))]
],
dtype=grad_inexact_dtypes,
)
def testConvGeneralDilatedGrad(self, lhs_shape, rhs_shape, dtype, strides,
padding, lhs_dil, rhs_dil, dimension_numbers,
perms, feature_group_count, batch_group_count):
@ -386,13 +375,11 @@ class LaxAutodiffTest(jtu.JaxTestCase):
check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
atol=tol, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_lhs_shape={}_rhs_shape={}".format(
jtu.format_shape_dtype_string(lhs_shape, dtype),
jtu.format_shape_dtype_string(rhs_shape, dtype)),
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype}
for lhs_shape in [(2,), (3, 2)] for rhs_shape in [(2,), (2, 4)]
for dtype in float_dtypes))
@jtu.sample_product(
lhs_shape=[(2,), (3, 2)],
rhs_shape=[(2,), (2, 4)],
dtype=float_dtypes,
)
def testDotGrad(self, lhs_shape, rhs_shape, dtype):
rng = jtu.rand_default(self.rng())
tol = {np.float16: 1e-1, np.float32: 1e-4}
@ -407,14 +394,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
s = str(jax.make_jaxpr(pullback)(gresult))
assert "Precision.HIGHEST" in s
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_lhs_shape={}_rhs_shape={}_dimension_numbers={}"
.format(jtu.format_shape_dtype_string(lhs_shape, dtype),
jtu.format_shape_dtype_string(rhs_shape, dtype),
dimension_numbers),
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
"dimension_numbers": dimension_numbers}
@jtu.sample_product(
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers)
for lhs_shape, rhs_shape, dimension_numbers in [
((3, 2), (2, 4), (([1], [0]), ([], []))),
((3, 5), (2, 5), (([1], [1]), ([], []))),
@ -423,7 +405,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
((3, 5, 2), (2, 4, 5), (([2], [0]), ([1], [2]))),
((7, 3, 5, 2), (2, 2, 4, 5), (([3], [0]), ([2], [3]))),
]
for dtype in float_dtypes))
],
dtype=float_dtypes,
)
def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype,
dimension_numbers):
rng = jtu.rand_small(self.rng())
@ -438,46 +422,36 @@ class LaxAutodiffTest(jtu.JaxTestCase):
s = str(jax.make_jaxpr(pullback)(gresult))
assert "Precision.HIGHEST" in s
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format(
shape, np.dtype(dtype).name, broadcast_sizes),
"shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes}
for shape in [(), (2, 3)]
for dtype in float_dtypes
for broadcast_sizes in [(), (2,), (1, 2)]))
@jtu.sample_product(
shape=[(), (2, 3)],
dtype=float_dtypes,
broadcast_sizes=[(), (2,), (1, 2)],
)
def testBroadcastGrad(self, shape, dtype, broadcast_sizes):
rng = jtu.rand_default(self.rng())
args = (rng(shape, dtype),)
broadcast = lambda x: lax.broadcast(x, broadcast_sizes)
check_grads(broadcast, args, 2, ["fwd", "rev"], eps=1.)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_outshape={}_bcdims={}".format(
jtu.format_shape_dtype_string(inshape, dtype),
outshape, broadcast_dimensions),
"inshape": inshape, "dtype": dtype, "outshape": outshape,
"dimensions": broadcast_dimensions}
@jtu.sample_product(
[dict(inshape=inshape, outshape=outshape, dimensions=broadcast_dimensions)
for inshape, outshape, broadcast_dimensions in [
([2], [2, 2], [0]),
([2], [2, 2], [1]),
([2], [2, 3], [0]),
([], [2, 3], []),
]
for dtype in float_dtypes))
],
dtype=float_dtypes,
)
def testBroadcastInDimGrad(self, inshape, dtype, outshape, dimensions):
rng = jtu.rand_default(self.rng())
operand = rng(inshape, dtype)
broadcast_in_dim = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
check_grads(broadcast_in_dim, (operand,), 2, ["fwd", "rev"], eps=1.)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_outshape={}_perm={}".format(
jtu.format_shape_dtype_string(arg_shape, dtype),
jtu.format_shape_dtype_string(out_shape, dtype),
permutation),
"arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
"permutation": permutation}
for dtype in float_dtypes
@jtu.sample_product(
[dict(arg_shape=arg_shape, out_shape=out_shape, permutation=permutation)
for arg_shape, out_shape, permutation in [
[(3, 4), (12,), None],
[(2, 1, 4), (8,), None],
@ -488,23 +462,26 @@ class LaxAutodiffTest(jtu.JaxTestCase):
[(2, 1, 4), (8,), (2, 0, 1)],
[(2, 2, 4), (2, 8), (0, 2, 1)],
[(2, 2, 4), (2, 8), (2, 0, 1)],
]))
]
],
dtype=float_dtypes,
)
def testReshapeGrad(self, arg_shape, out_shape, permutation, dtype):
rng = jtu.rand_default(self.rng())
operand = rng(arg_shape, dtype)
reshape = lambda x: lax.reshape(x, out_shape, permutation)
check_grads(reshape, (operand,), 2, ["fwd", "rev"], eps=1.)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_pads={}"
.format(jtu.format_shape_dtype_string(shape, dtype), pads),
"shape": shape, "dtype": dtype, "pads": pads}
for dtype in float_dtypes
@jtu.sample_product(
[dict(shape=shape, pads=pads)
for shape, paddings in [
[(), [()]],
((2, 3), [[(1, 2, 1), (0, 1, 0)], [(-1, 0, 0), (-1, 0, 2)]]),
]
for pads in paddings))
for pads in paddings
],
dtype=float_dtypes,
)
def testPadGrad(self, shape, dtype, pads):
rng = jtu.rand_small(self.rng())
operand = rng(shape, dtype)
@ -546,14 +523,13 @@ class LaxAutodiffTest(jtu.JaxTestCase):
# self.assertEqual(result, 0.0)
self.assertAllClose(result, np.nan)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_predshape={}_argshapes={}".format(
jtu.format_shape_dtype_string(pred_shape, np.bool_),
jtu.format_shape_dtype_string(arg_shape, dtype)),
"pred_shape": pred_shape, "arg_shape": arg_shape, "dtype": dtype}
@jtu.sample_product(
[dict(arg_shape=arg_shape, pred_shape=pred_shape)
for arg_shape in [(), (3,), (2, 3)]
for pred_shape in ([(), arg_shape] if arg_shape else [()])
for dtype in float_dtypes))
],
dtype=float_dtypes,
)
def testSelectGrad(self, pred_shape, arg_shape, dtype):
rng = jtu.rand_default(self.rng())
pred = rng(pred_shape, np.bool_)
@ -562,13 +538,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
select = lambda on_true, on_false: lax.select(pred, on_true, on_false)
check_grads(select, (on_true, on_false), 2, ["fwd", "rev"], eps=1.)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}_start_indices={}_limit_indices={}_strides={}".format(
jtu.format_shape_dtype_string(shape, dtype),
start_indices, limit_indices, strides),
"shape": shape, "dtype": dtype, "starts": start_indices,
"limits": limit_indices, "strides": strides}
@jtu.sample_product(
[dict(shape=shape, starts=start_indices, limits=limit_indices,
strides=strides)
for shape, start_indices, limit_indices, strides in [
[(3,), (1,), (2,), None],
[(7,), (4,), (7,), None],
@ -581,44 +553,43 @@ class LaxAutodiffTest(jtu.JaxTestCase):
[(5, 3), (1, 1), (5, 3), (2, 1)],
[(3, 3, 5), (0, 2, 0), (3, 2, 5), (1, 2, 1)]
]
for dtype in float_dtypes))
],
dtype=float_dtypes,
)
def testSliceGrad(self, shape, dtype, starts, limits, strides):
rng = jtu.rand_default(self.rng())
operand = rng(shape, dtype)
slice = lambda x: lax.slice(x, starts, limits, strides)
check_grads(slice, (operand,), 2, ["fwd", "rev"], eps=1.)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_start_indices={}_size_indices={}".format(
jtu.format_shape_dtype_string(shape, dtype),
start_indices, size_indices),
"shape": shape, "dtype": dtype, "start_indices": start_indices,
"size_indices": size_indices}
@jtu.sample_product(
[dict(shape=shape, start_indices=start_indices, size_indices=size_indices)
for shape, start_indices, size_indices in [
[(3,), (1,), (1,)],
[(5, 3), (1, 1), (3, 1)],
[(7, 5, 3), (4, 1, 0), (2, 0, 1)],
]
for dtype in float_dtypes))
],
dtype=float_dtypes,
)
def testDynamicSliceGrad(self, shape, dtype, start_indices, size_indices):
rng = jtu.rand_default(self.rng())
operand = rng(shape, dtype)
dynamic_slice = lambda x: lax.dynamic_slice(x, start_indices, size_indices)
check_grads(dynamic_slice, (operand,), 2, ["fwd", "rev"], eps=1.)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_start_indices={}_update_shape={}".format(
jtu.format_shape_dtype_string(shape, dtype),
start_indices, update_shape),
"shape": shape, "dtype": dtype, "start_indices": start_indices,
"update_shape": update_shape}
@jtu.sample_product(
[dict(shape=shape, start_indices=start_indices, update_shape=update_shape)
for shape, start_indices, update_shape in [
[(3,), (1,), (1,)],
[(5, 3), (1, 1), (3, 1)],
[(7, 5, 3), (4, 1, 0), (2, 0, 1)],
]
for dtype in float_dtypes))
def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices, update_shape):
],
dtype=float_dtypes,
)
def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices,
update_shape):
rng = jtu.rand_default(self.rng())
operand = rng(shape, dtype)
update = rng(update_shape, dtype)
@ -664,28 +635,25 @@ class LaxAutodiffTest(jtu.JaxTestCase):
result2, _ = jax.value_and_grad(f, 0)(x, y)
self.assertAllClose(result1, result2)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_perm={}".format(
jtu.format_shape_dtype_string(shape, dtype), perm),
"shape": shape, "dtype": dtype, "perm": perm}
@jtu.sample_product(
[dict(shape=shape, perm=perm)
for shape, perm in [
[(3, 4), (1, 0)],
[(3, 4), (0, 1)],
[(3, 4, 5), (2, 1, 0)],
[(3, 4, 5), (1, 0, 2)],
]
for dtype in float_dtypes))
],
dtype=float_dtypes,
)
def testTransposeGrad(self, shape, dtype, perm):
rng = jtu.rand_default(self.rng())
operand = rng(shape, dtype)
transpose = lambda x: lax.transpose(x, perm)
check_grads(transpose, (operand,), 2, ["fwd", "rev"], eps=1.)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_inshape={}_reducedims={}"
.format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims),
"op": op, "init_val": init_val, "shape": shape, "dtype": dtype,
"dims": dims, "rng_factory": rng_factory}
@jtu.sample_product(
[dict(init_val=init_val, op=op, dtype=dtype, rng_factory=rng_factory)
for init_val, op, dtypes, rng_factory in [
(0, lax.add, float_dtypes + jtu.dtypes.complex, jtu.rand_default),
(-np.inf, lax.max, grad_inexact_dtypes, jtu.rand_unique_int),
@ -693,6 +661,8 @@ class LaxAutodiffTest(jtu.JaxTestCase):
(1, lax.mul, grad_float_dtypes, partial(jtu.rand_default, scale=1)),
]
for dtype in dtypes
],
[dict(shape=shape, dims=dims)
for shape, dims in [
[(), ()],
[(3, 4, 5), ()],
@ -702,7 +672,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
[(3, 4, 5), (0, 1, 2)],
[(3, 1), (1,)],
[(3, 0, 5), (1,)],
]))
]
],
)
def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory):
rng = rng_factory(self.rng())
if jtu.device_under_test() == "tpu" and op is lax.mul:
@ -718,11 +690,8 @@ class LaxAutodiffTest(jtu.JaxTestCase):
if op not in (lax.max, lax.min) or all(d > 0 for d in shape):
check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_reducedims={}"
.format(jtu.format_shape_dtype_string(shape, dtype), dims),
"shape": shape, "dtype": dtype, "dims": dims}
for dtype in grad_float_dtypes
@jtu.sample_product(
[dict(shape=shape, dims=dims)
for shape, dims in [
[(3, 4, 5), ()],
[(3, 4, 5), (0,)],
@ -731,7 +700,10 @@ class LaxAutodiffTest(jtu.JaxTestCase):
[(3, 4, 5), (0, 1, 2)],
[(3, 1), (1,)],
[(3, 0, 5), (1,)],
]))
]
],
dtype=grad_float_dtypes,
)
def testReducePairGrad(self, shape, dtype, dims):
rng = jtu.rand_default(self.rng(), scale=1)
tol = {np.float32: 1e-2, np.float64: 1e-4}
@ -742,20 +714,16 @@ class LaxAutodiffTest(jtu.JaxTestCase):
reduce = lambda xs, ys: lax.reduce((xs, ys), init_vals, op, dims)
check_grads(reduce, operands, 2, ["fwd", "rev"], tol, tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}"
"_basedilation={}_windowdilation={}")
.format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims,
strides, padding, base_dilation, window_dilation),
"op": op, "init_val": init_val, "dtype": dtype, "shape": shape,
"dims": dims, "strides": strides, "padding": padding,
"base_dilation": base_dilation, "window_dilation": window_dilation,
"rng_factory": rng_factory}
@jtu.sample_product(
[dict(init_val=init_val, op=op, dtype=dtype, rng_factory=rng_factory,
shape=shape, dims=dims, strides=strides, padding=padding,
base_dilation=base_dilation, window_dilation=window_dilation)
for init_val, op, dtypes, rng_factory in [
(0, lax.add, grad_float_dtypes, jtu.rand_small),
(-np.inf, lax.max, grad_float_dtypes, jtu.rand_unique_int),
(np.inf, lax.min, grad_float_dtypes, jtu.rand_unique_int),
]
for dtype in dtypes
for shape, dims, strides, padding, base_dilation, window_dilation in (
itertools.chain(
itertools.product(
@ -772,7 +740,8 @@ class LaxAutodiffTest(jtu.JaxTestCase):
["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]],
[(1, 1, 1, 1)] + ([(2, 1, 3, 2)]),
[(1, 1, 1, 1)] + ([(1, 2, 2, 1)] if op is lax.add else []))))
for dtype in dtypes))
],
)
@jtu.ignore_warning(category=UserWarning,
message="Using reduced precision for gradient.*")
def testReduceWindowGrad(
@ -808,20 +777,20 @@ class LaxAutodiffTest(jtu.JaxTestCase):
check_grads(fun, (operand,), gradient_order, ["fwd", "rev"], tol, tol,
eps)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_shape={}_axis={}_reverse={}"
.format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis,
reverse),
"op": op, "shape": shape, "dtype": dtype,
"axis": axis, "reverse": reverse}
@jtu.sample_product(
[dict(op=op, dtype=dtype)
for op, types in [
(lax.cumsum, [np.float32, np.float64]),
(lax.cumprod, [np.float32, np.float64]),
]
for dtype in types
],
[dict(shape=shape, axis=axis)
for shape in [[10], [3, 4, 5]]
for axis in range(len(shape))
for reverse in [False, True]))
],
reverse=[False, True],
)
def testCumulativeReduceGrad(self, op, shape, dtype, axis, reverse):
rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
else jtu.rand_small)
@ -831,33 +800,30 @@ class LaxAutodiffTest(jtu.JaxTestCase):
# TODO(b/205052657): enable more tests when supported
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_axis={}_isstable={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis, is_stable),
"shape": shape, "dtype": dtype, "axis": axis, "is_stable": is_stable}
for dtype in [np.float32]
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in [(5,), (5, 7)]
for axis in [len(shape) - 1]
for is_stable in [False, True]))
],
dtype=[np.float32],
is_stable=[False, True],
)
def testSortGrad(self, shape, dtype, axis, is_stable):
rng = jtu.rand_default(self.rng())
rng = jtu.rand_unique_int(self.rng())
operand = rng(shape, dtype)
sort = lambda x: lax.sort(x, dimension=axis, is_stable=is_stable)
check_grads(sort, (operand,), 2, ["fwd", "rev"], eps=1e-2)
# TODO(b/205052657): enable more tests when supported
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_keyshape={}_valshape={}_axis={}_isstable={}".format(
jtu.format_shape_dtype_string(shape, key_dtype),
jtu.format_shape_dtype_string(shape, val_dtype),
axis, is_stable),
"shape": shape, "key_dtype": key_dtype, "val_dtype": val_dtype,
"axis": axis, "is_stable": is_stable}
for key_dtype in [np.float32]
for val_dtype in [np.float32]
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in [(3,), (5, 3)]
for axis in [len(shape) - 1]
for is_stable in [False, True]))
],
key_dtype=[np.float32],
val_dtype=[np.float32],
is_stable=[False, True],
)
def testSortKeyValGrad(self, shape, key_dtype, val_dtype, axis, is_stable):
rng = jtu.rand_default(self.rng())
# This test relies on the property that wherever keys are tied, values are
@ -873,30 +839,28 @@ class LaxAutodiffTest(jtu.JaxTestCase):
fun = lambda keys, values: lax.sort_key_val(keys, values, axis, is_stable)
check_grads(fun, (keys, values), 2, ["fwd", "rev"], 1e-2, 1e-2, 1e-2)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_k={}".format(
jtu.format_shape_dtype_string(shape, dtype), k),
"shape": shape, "dtype": dtype, "k": k}
for dtype in [np.float32,]
for shape in [(4,), (5, 5), (2, 1, 4)]
for k in [1, 3]))
@jtu.sample_product(
dtype=[np.float32,],
shape=[(4,), (5, 5), (2, 1, 4)],
k=[1, 3],
)
def testTopKGrad(self, shape, dtype, k):
flat_values = np.arange(prod(shape), dtype=dtype)
values = self.rng().permutation(flat_values).reshape(shape)
fun = lambda vs: lax.top_k(vs, k=k)[0]
check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_idxs={}_axes={}".format(
jtu.format_shape_dtype_string(shape, dtype), idxs, axes),
"shape": shape, "dtype": dtype, "idxs": idxs, "axes": axes}
for dtype in float_dtypes
@jtu.sample_product(
[dict(shape=shape, idxs=idxs, axes=axes)
for shape, idxs, axes in [
[(3, 4, 5), (np.array([0, 2, 1]),), (0,)],
[(3, 4, 5), (np.array([-1, -2]),), (0,)],
[(3, 4, 5), (np.array([0, 2]), np.array([1, 3])), (0, 1)],
[(3, 4, 5), (np.array([0, 2]), np.array([1, 3])), (0, 2)],
]))
]
],
dtype=float_dtypes,
)
@jax.numpy_rank_promotion('allow') # Test explicitly exercises implicit rank promotion.
def testIndexTakeGrad(self, shape, dtype, idxs, axes):
rng = jtu.rand_default(self.rng())
@ -904,16 +868,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
index_take = lambda src: lax.index_take(src, idxs, axes)
check_grads(index_take, (src,), 2, ["fwd", "rev"], eps=1.)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
f"_idxs={jtu.format_shape_dtype_string(idxs.shape, idxs.dtype)}"
f"_dnums={dnums}_slice_sizes={slice_sizes}_mode={mode}"
f"_iteration={iteration}",
"shape": shape, "dtype": dtype, "idxs_shape": idxs.shape,
"idxs_dtype": idxs.dtype, "dnums": dnums, "slice_sizes": slice_sizes,
"max_idx": max_idx, "mode": mode}
for dtype in grad_float_dtypes
@jtu.sample_product(
[dict(shape=shape, idxs_shape=idxs.shape, idxs_dtype=idxs.dtype,
dnums=dnums, slice_sizes=slice_sizes, max_idx=max_idx)
for shape, idxs, dnums, slice_sizes, max_idx in [
((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers(
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
@ -925,10 +882,13 @@ class LaxAutodiffTest(jtu.JaxTestCase):
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
(1, 3), 3),
]
for mode in ["clip", "fill", "promise_in_bounds"]
for iteration in range(5)))
],
dtype=grad_float_dtypes,
mode=["clip", "fill", "promise_in_bounds"],
iteration=range(5),
)
def testGatherGrad(self, shape, dtype, idxs_shape, idxs_dtype, dnums,
slice_sizes, mode, max_idx):
slice_sizes, mode, max_idx, iteration):
rng = jtu.rand_default(self.rng())
if mode == "promise_in_bounds":
rng_idx = jtu.rand_int(self.rng(), high=max_idx)
@ -945,16 +905,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
x = rng(shape, dtype)
check_grads(gather, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
f"_shape={jtu.format_shape_dtype_string(arg_shape, dtype)}"
f"_idxs={jtu.format_shape_dtype_string(idxs.shape, idxs.dtype)}"
f"_update={update_shape}_dnums={dnums}_mode={mode}"
f"_iteration={iteration}",
"arg_shape": arg_shape, "dtype": dtype, "idxs_shape": idxs.shape,
"idxs_dtype": idxs.dtype, "update_shape": update_shape, "dnums": dnums,
"max_idx": max_idx, "mode": mode}
for dtype in grad_float_dtypes
@jtu.sample_product(
[dict(arg_shape=arg_shape, idxs_shape=idxs.shape, idxs_dtype=idxs.dtype,
dnums=dnums, update_shape=update_shape, max_idx=max_idx)
for arg_shape, idxs, update_shape, dnums, max_idx in [
((5,), np.array([[0], [2]]), (2,),
lax.ScatterDimensionNumbers(update_window_dims=(),
@ -969,10 +922,13 @@ class LaxAutodiffTest(jtu.JaxTestCase):
inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,)), 3),
]
for mode in ["clip", "fill", "promise_in_bounds"]
for iteration in range(5)))
],
dtype=grad_float_dtypes,
mode=["clip", "fill", "promise_in_bounds"],
iteration=range(5),
)
def testScatterAddGrad(self, arg_shape, dtype, idxs_shape, idxs_dtype,
update_shape, dnums, max_idx, mode):
update_shape, dnums, max_idx, mode, iteration):
rng = jtu.rand_default(self.rng())
if mode == "promise_in_bounds":
rng_idx = jtu.rand_int(self.rng(), high=max_idx)
@ -987,13 +943,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
y = rng(update_shape, dtype)
check_grads(scatter_add, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
jtu.format_shape_dtype_string(arg_shape, dtype),
idxs, update_shape, dnums),
"arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
"update_shape": update_shape, "dnums": dnums, "rng_idx_factory": rng_idx_factory}
for dtype in grad_float_dtypes
@jtu.sample_product(
[dict(arg_shape=arg_shape, idxs=idxs, dnums=dnums,
update_shape=update_shape, max_idx=max_idx)
for arg_shape, idxs, update_shape, dnums, max_idx in [
((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,),
@ -1005,13 +957,15 @@ class LaxAutodiffTest(jtu.JaxTestCase):
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,)), 3),
]
# Scatters with conflicting indices are not deterministic on GPU, so we
# use indices that do not collide.
for rng_idx_factory in [partial(jtu.rand_unique_int, high=max_idx)]))
],
dtype=grad_float_dtypes,
)
def testScatterGrad(self, arg_shape, dtype, idxs, update_shape, dnums,
rng_idx_factory):
max_idx):
# Scatters with conflicting indices are not deterministic on GPU, so we
# use indices that do not collide.
rng_idx = jtu.rand_unique_int(self.rng(), high=max_idx)
rng = jtu.rand_default(self.rng())
rng_idx = rng_idx_factory(self.rng())
idxs = rng_idx(idxs.shape, idxs.dtype)
scatter = lambda x, y: lax.scatter(x, idxs, y, dimension_numbers=dnums)
x = rng(arg_shape, dtype)
@ -1028,13 +982,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
check_grads(f, (rng((5, 5), np.float32),), 2, ["fwd", "rev"], 1e-2, 1e-2,
1.)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
jtu.format_shape_dtype_string(arg_shape, dtype),
idxs, update_shape, dnums),
"arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
"update_shape": update_shape, "dnums": dnums}
for dtype in grad_float_dtypes
@jtu.sample_product(
[dict(arg_shape=arg_shape, idxs=idxs, dnums=dnums,
update_shape=update_shape)
for arg_shape, idxs, update_shape, dnums in [
((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,),
@ -1045,7 +995,10 @@ class LaxAutodiffTest(jtu.JaxTestCase):
((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
]))
]
],
dtype=grad_float_dtypes,
)
def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums):
rng = jtu.rand_default(self.rng())
rng_idx = jtu.rand_int(self.rng(), high=max(arg_shape))
@ -1055,13 +1008,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
y = rng(update_shape, dtype)
check_grads(scatter_max, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
jtu.format_shape_dtype_string(arg_shape, dtype),
idxs, update_shape, dnums),
"arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
"update_shape": update_shape, "dnums": dnums}
for dtype in grad_float_dtypes
@jtu.sample_product(
[dict(arg_shape=arg_shape, idxs=idxs, dnums=dnums,
update_shape=update_shape)
for arg_shape, idxs, update_shape, dnums in [
((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,),
@ -1072,7 +1021,10 @@ class LaxAutodiffTest(jtu.JaxTestCase):
((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
]))
]
],
dtype=grad_float_dtypes,
)
def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums):
rng = jtu.rand_default(self.rng())
rng_idx = jtu.rand_int(self.rng(), high=max(arg_shape))

View File

@ -137,7 +137,7 @@ if numpy_version >= (1, 22): # where keyword added in numpy 1.22
op_record("nanmean", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [],
inexact=True),
op_record("nanvar", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [],
inexact=True),
inexact=True, tolerance={np.float16: 3e-3}),
op_record("nanstd", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [],
inexact=True),
]
@ -148,7 +148,7 @@ JAX_REDUCER_NO_DTYPE_RECORDS = [
op_record("max", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []),
op_record("min", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []),
op_record("var", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
inexact=True),
inexact=True, tolerance={jnp.bfloat16: 2e-2}),
op_record("std", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
inexact=True),
op_record("nanmax", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan, []),
@ -179,23 +179,25 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
for a in out]
return f
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": "{}_inshape={}_axis={}_dtype={}_keepdims={}".format(
rec.test_name.capitalize(),
jtu.format_shape_dtype_string(shape, dtype), axis,
"None" if out_dtype is None else np.dtype(out_dtype).name, keepdims),
"rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, "out_dtype": out_dtype,
"np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
"axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
for shape in rec.shapes for dtype in rec.dtypes
for out_dtype in [None] + rec.dtypes if out_dtype not in unsigned_dtypes
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact)],
[dict(shape=shape, axis=axis, dtype=dtype)
for shape in rec.shapes
for dtype in rec.dtypes
for axis in list(range(-len(shape), len(shape))) + [None]
for keepdims in [False, True]
if jtu.is_valid_shape(shape, dtype))
for rec in JAX_REDUCER_RECORDS))
def testReducer(self, np_op, jnp_op, rng_factory, shape, dtype, out_dtype,
if jtu.is_valid_shape(shape, dtype)
],
out_dtype=[out_dtype for out_dtype in [None] + rec.dtypes
if out_dtype not in unsigned_dtypes],
keepdims=[False, True],
)
for rec in JAX_REDUCER_RECORDS
))
def testReducer(self, name, rng_factory, shape, dtype, out_dtype,
axis, keepdims, inexact):
np_op = getattr(np, name)
jnp_op = getattr(jnp, name)
rng = rng_factory(self.rng())
@jtu.ignore_warning(category=np.ComplexWarning)
@jtu.ignore_warning(category=RuntimeWarning,
@ -223,23 +225,26 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp_fun, args_maker, atol=tol,
rtol=tol)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format(
rec.test_name.capitalize(),
jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
"rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
"np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
"axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact,
tolerance=rec.tolerance)],
[dict(shape=shape, axis=axis, dtype=dtype)
for shape in rec.shapes for dtype in rec.dtypes
for axis in list(range(-len(shape), len(shape))) + [None]
for keepdims in [False, True]
if jtu.is_valid_shape(shape, dtype))
for rec in JAX_REDUCER_NO_DTYPE_RECORDS))
def testReducerNoDtype(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
keepdims, inexact):
if jtu.is_valid_shape(shape, dtype)
],
keepdims=[False, True],
)
for rec in JAX_REDUCER_NO_DTYPE_RECORDS
))
def testReducerNoDtype(self, name, rng_factory, shape, dtype, axis,
keepdims, inexact, tolerance):
np_op = getattr(np, name)
jnp_op = getattr(jnp, name)
rng = rng_factory(self.rng())
is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
is_bf16_nan_test = (dtype == jnp.bfloat16 and
rng_factory.__name__ == 'rand_some_nan')
@jtu.ignore_warning(category=RuntimeWarning,
message="Degrees of freedom <= 0 for slice.*")
@jtu.ignore_warning(category=RuntimeWarning,
@ -255,25 +260,28 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims)
args_maker = lambda: [rng(shape, dtype)]
tol = {np.float16: 0.002}
tol = jtu.join_tolerance({np.float16: 0.002},
tolerance or jtu.default_tolerance())
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol, atol=tol)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}".format(
rec.test_name.capitalize(),
jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial),
"rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
"np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
"initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact)],
[dict(shape=shape, axis=axis, dtype=dtype)
for shape in rec.shapes for dtype in rec.dtypes
for axis in list(range(-len(shape), len(shape))) + [None]
for initial in [0, 1] for keepdims in [False, True]
if jtu.is_valid_shape(shape, dtype))
for rec in JAX_REDUCER_INITIAL_RECORDS))
def testReducerInitial(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
if jtu.is_valid_shape(shape, dtype)
],
initial=[0, 1],
keepdims=[False, True],
)
for rec in JAX_REDUCER_INITIAL_RECORDS
))
def testReducerInitial(self, name, rng_factory, shape, dtype, axis,
keepdims, initial, inexact):
np_op = getattr(np, name)
jnp_op = getattr(jnp, name)
rng = rng_factory(self.rng())
is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
@jtu.ignore_warning(category=RuntimeWarning,
@ -295,25 +303,27 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}_promote_integers={}".format(
rec.test_name.capitalize(),
jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial, promote_integers),
"rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
"np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
"initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact,
"promote_integers": promote_integers}
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact)],
[dict(shape=shape, axis=axis, dtype=dtype)
for shape in rec.shapes for dtype in rec.dtypes
for axis in list(range(-len(shape), len(shape))) + [None]
for initial in [0, 1] for keepdims in [False, True]
for promote_integers in [True, False]
if jtu.is_valid_shape(shape, dtype))
for rec in JAX_REDUCER_PROMOTE_INT_RECORDS))
def testReducerPromoteInt(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
if jtu.is_valid_shape(shape, dtype)
],
initial=[0, 1],
keepdims=[False, True],
promote_integers=[False, True],
)
for rec in JAX_REDUCER_PROMOTE_INT_RECORDS
))
def testReducerPromoteInt(self, name, rng_factory, shape, dtype, axis,
keepdims, initial, inexact, promote_integers):
np_op = getattr(np, name)
jnp_op = getattr(jnp, name)
rng = rng_factory(self.rng())
is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
is_bf16_nan_test = (dtype == jnp.bfloat16 and
rng_factory.__name__ == 'rand_some_nan')
@jtu.ignore_warning(category=RuntimeWarning,
message="Degrees of freedom <= 0 for slice.*")
@jtu.ignore_warning(category=np.ComplexWarning)
@ -338,21 +348,22 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format(
rec.test_name.capitalize(),
jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
"rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
"np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
"axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact)],
[dict(shape=shape, axis=axis)
for shape in rec.shapes if np.prod(shape) == 0
for dtype in rec.dtypes
for keepdims in [False, True]
for axis in range(-len(shape), len(shape)) if shape[axis] >= 1)
for rec in JAX_REDUCER_INITIAL_RECORDS))
def testReducerNoInitialZeroDims(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
for axis in range(-len(shape), len(shape)) if shape[axis] >= 1
],
dtype=rec.dtypes,
keepdims=[False, True],
)
for rec in JAX_REDUCER_INITIAL_RECORDS
))
def testReducerNoInitialZeroDims(self, name, rng_factory, shape, dtype, axis,
keepdims, inexact):
np_op = getattr(np, name)
jnp_op = getattr(jnp, name)
rng = rng_factory(self.rng())
is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
@jtu.ignore_warning(category=RuntimeWarning,
@ -374,23 +385,24 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}_whereshape={}".format(
rec.test_name.capitalize(),
jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial,
jtu.format_shape_dtype_string(whereshape, bool)),
"rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
"np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), "whereshape": whereshape,
"initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact)],
[dict(shape=shape, axis=axis, dtype=dtype, whereshape=whereshape)
for shape in rec.shapes for dtype in rec.dtypes
for whereshape in _compatible_shapes(shape)
for axis in list(range(-len(shape), len(shape))) + [None]
for initial in [0, 1] for keepdims in [False, True]
if jtu.is_valid_shape(shape, dtype))
for rec in JAX_REDUCER_INITIAL_RECORDS))
def testReducerWhere(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
if jtu.is_valid_shape(shape, dtype)
for whereshape in _compatible_shapes(shape)
],
initial=[0, 1],
keepdims=[False, True],
)
for rec in JAX_REDUCER_INITIAL_RECORDS
))
def testReducerWhere(self, name, rng_factory, shape, dtype, axis,
keepdims, initial, inexact, whereshape):
np_op = getattr(np, name)
jnp_op = getattr(jnp, name)
if (shape in [()] + scalar_shapes and
dtype in [jnp.int16, jnp.uint16] and
jnp_op in [jnp.min, jnp.max]):
@ -417,23 +429,23 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}_whereshape={}".format(
rec.test_name.capitalize(),
jtu.format_shape_dtype_string(shape, dtype), axis, keepdims,
jtu.format_shape_dtype_string(whereshape, bool)),
"rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
"np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), "whereshape": whereshape,
"axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
for shape in rec.shapes for dtype in rec.dtypes
for whereshape in _compatible_shapes(shape)
for axis in list(range(-len(shape), len(shape))) + [None]
for keepdims in [False, True]
if jtu.is_valid_shape(shape, dtype))
for rec in JAX_REDUCER_WHERE_NO_INITIAL_RECORDS))
def testReducerWhereNoInitial(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
keepdims, inexact, whereshape):
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact,
tol=rec.tolerance)],
[dict(shape=shape, axis=axis, dtype=dtype, whereshape=whereshape)
for shape in rec.shapes for dtype in rec.dtypes
for whereshape in _compatible_shapes(shape)
for axis in list(range(-len(shape), len(shape))) + [None]
if jtu.is_valid_shape(shape, dtype)
],
keepdims=[False, True],
) for rec in JAX_REDUCER_WHERE_NO_INITIAL_RECORDS
))
def testReducerWhereNoInitial(self, name, rng_factory, shape, dtype, axis,
keepdims, inexact, whereshape, tol):
np_op = getattr(np, name)
jnp_op = getattr(jnp, name)
rng = rng_factory(self.rng())
is_bf16_nan_test = dtype == jnp.bfloat16
# Do not pass where via args_maker as that is incompatible with _promote_like_jnp.
@ -458,7 +470,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
args_maker = lambda: [rng(shape, dtype)]
if numpy_version >= (1, 20, 2) or np_op.__name__ in ("all", "any"):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, atol=tol, rtol=tol)
self._CompileAndCheck(jnp_fun, args_maker)
def testReductionOfOutOfBoundsAxis(self): # Issue 888
@ -510,19 +522,14 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=check_dtypes,
rtol=tol, atol=tol)
@parameterized.named_parameters(
jtu.cases_from_list(
{"testcase_name":
"_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}"
.format(shape, dtype.__name__, out_dtype.__name__, axis, ddof, keepdims),
"shape": shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis,
"ddof": ddof, "keepdims": keepdims}
for shape in [(5,), (10, 5)]
for dtype in all_dtypes
for out_dtype in inexact_dtypes
for axis in [None, 0, -1]
for ddof in [0, 1, 2]
for keepdims in [False, True]))
@jtu.sample_product(
shape=[(5,), (10, 5)],
dtype=all_dtypes,
out_dtype=inexact_dtypes,
axis=[None, 0, -1],
ddof=[0, 1, 2],
keepdims=[False, True],
)
def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims):
rng = jtu.rand_default(self.rng())
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
@ -547,19 +554,14 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol,
atol=tol)
@parameterized.named_parameters(
jtu.cases_from_list(
{"testcase_name":
"_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}"
.format(shape, dtype, out_dtype, axis, ddof, keepdims),
"shape": shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis,
"ddof": ddof, "keepdims": keepdims}
for shape in [(5,), (10, 5)]
for dtype in all_dtypes
for out_dtype in inexact_dtypes
for axis in [None, 0, -1]
for ddof in [0, 1, 2]
for keepdims in [False, True]))
@jtu.sample_product(
shape=[(5,), (10, 5)],
dtype=all_dtypes,
out_dtype=inexact_dtypes,
axis=[None, 0, -1],
ddof=[0, 1, 2],
keepdims=[False, True],
)
def testNanVar(self, shape, dtype, out_dtype, axis, ddof, keepdims):
rng = jtu.rand_some_nan(self.rng())
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
@ -594,22 +596,20 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
z = jax.grad(jnp.nanstd)(x)
self.assertEqual(jnp.isnan(z).sum(), 0)
@parameterized.named_parameters(
jtu.cases_from_list(
{"testcase_name":
"_shape={}_dtype={}_y_shape={}_y_dtype={}_rowvar={}_ddof={}_bias={}_fweights={}_aweights={}".format(
shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, aweights),
"shape": shape, "y_shape": y_shape, "dtype": dtype, "y_dtype": y_dtype,"rowvar": rowvar, "ddof": ddof,
"bias": bias, "fweights": fweights, "aweights": aweights}
for shape in [(5,), (10, 5), (5, 10)]
for dtype in all_dtypes
for y_dtype in [None, dtype]
for rowvar in [True, False]
for y_shape in _get_y_shapes(y_dtype, shape, rowvar)
for bias in [True, False]
for ddof in [None, 2, 3]
for fweights in [True, False]
for aweights in [True, False]))
@jtu.sample_product(
[dict(shape=shape, dtype=dtype, y_dtype=y_dtype, rowvar=rowvar,
y_shape=y_shape)
for shape in [(5,), (10, 5), (5, 10)]
for dtype in all_dtypes
for y_dtype in [None, dtype]
for rowvar in [True, False]
for y_shape in _get_y_shapes(y_dtype, shape, rowvar)
],
bias=[True, False],
ddof=[None, 2, 3],
fweights=[True, False],
aweights=[True, False],
)
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
def testCov(self, shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, aweights):
rng = jtu.rand_default(self.rng())

View File

@ -15,7 +15,6 @@
from functools import partial
import unittest
from absl.testing import parameterized
from absl.testing import absltest
import numpy as np
import scipy.sparse.linalg
@ -93,15 +92,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
M = None
return M
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}_preconditioner={}".format(
jtu.format_shape_dtype_string(shape, dtype),
preconditioner),
"shape": shape, "dtype": dtype, "preconditioner": preconditioner}
for shape in [(4, 4), (7, 7)]
for dtype in [np.float64, np.complex128]
for preconditioner in [None, 'identity', 'exact', 'random']))
@jtu.sample_product(
shape=[(4, 4), (7, 7)],
dtype=[np.float64, np.complex128],
preconditioner=[None, 'identity', 'exact', 'random'],
)
def test_cg_against_scipy(self, shape, dtype, preconditioner):
if not config.x64_enabled:
raise unittest.SkipTest("requires x64 mode")
@ -132,12 +127,10 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
args_maker,
tol=1e-6)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
f"_shape={jtu.format_shape_dtype_string(shape, dtype)}",
"shape": shape, "dtype": dtype}
for shape in [(2, 2)]
for dtype in float_types + complex_types))
@jtu.sample_product(
shape=[(2, 2)],
dtype=float_types + complex_types,
)
def test_cg_as_solve(self, shape, dtype):
rng = jtu.rand_default(self.rng())
@ -219,16 +212,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertTrue(dtypes.is_weakly_typed(x))
# BICGSTAB
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}_preconditioner={}".format(
jtu.format_shape_dtype_string(shape, dtype),
preconditioner),
"shape": shape, "dtype": dtype, "preconditioner": preconditioner}
for shape in [(5, 5)]
for dtype in [np.float64, np.complex128]
for preconditioner in [None, 'identity', 'exact', 'random']
))
@jtu.sample_product(
shape=[(5, 5)],
dtype=[np.float64, np.complex128],
preconditioner=[None, 'identity', 'exact', 'random'],
)
def test_bicgstab_against_scipy(
self, shape, dtype, preconditioner):
if not config.jax_enable_x64:
@ -266,16 +254,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
args_maker,
tol=1e-4)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}_preconditioner={}".format(
jtu.format_shape_dtype_string(shape, dtype),
preconditioner),
"shape": shape, "dtype": dtype, "preconditioner": preconditioner}
for shape in [(2, 2), (7, 7)]
for dtype in float_types + complex_types
for preconditioner in [None, 'identity', 'exact']
))
@jtu.sample_product(
shape=[(2, 2), (7, 7)],
dtype=float_types + complex_types,
preconditioner=[None, 'identity', 'exact'],
)
@jtu.skip_on_devices("gpu")
def test_bicgstab_on_identity_system(self, shape, dtype, preconditioner):
A = jnp.eye(shape[1], dtype=dtype)
@ -290,17 +273,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
solution_tol = 1e-8 if using_x64 else 1e-4
self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}_preconditioner={}".format(
jtu.format_shape_dtype_string(shape, dtype),
preconditioner),
"shape": shape, "dtype": dtype, "preconditioner": preconditioner
}
for shape in [(2, 2), (4, 4)]
for dtype in float_types + complex_types
for preconditioner in [None, 'identity', 'exact']
))
@jtu.sample_product(
shape=[(2, 2), (4, 4)],
dtype=float_types + complex_types,
preconditioner=[None, 'identity', 'exact'],
)
@jtu.skip_on_devices("gpu")
def test_bicgstab_on_random_system(self, shape, dtype, preconditioner):
rng = jtu.rand_default(self.rng())
@ -339,18 +316,12 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertAllClose(expected, actual)
# GMRES
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}_preconditioner={}_solve_method={}".format(
jtu.format_shape_dtype_string(shape, dtype),
preconditioner,
solve_method),
"shape": shape, "dtype": dtype, "preconditioner": preconditioner,
"solve_method": solve_method}
for shape in [(3, 3)]
for dtype in [np.float64, np.complex128]
for preconditioner in [None, 'identity', 'exact', 'random']
for solve_method in ['incremental', 'batched']))
@jtu.sample_product(
shape=[(3, 3)],
dtype=[np.float64, np.complex128],
preconditioner=[None, 'identity', 'exact', 'random'],
solve_method=['incremental', 'batched'],
)
def test_gmres_against_scipy(
self, shape, dtype, preconditioner, solve_method):
if not config.x64_enabled:
@ -380,27 +351,20 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
partial(scipy_gmres, M=M, restart=2, maxiter=1),
partial(lax_gmres, M=M, restart=2, maxiter=1, solve_method=solve_method),
args_maker,
tol=1e-10)
tol=1e-9)
self._CheckAgainstNumpy(
np.linalg.solve,
partial(lax_gmres, M=M, atol=1e-6, solve_method=solve_method),
args_maker,
tol=1e-10)
tol=1e-8)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}_preconditioner={}_solve_method={}".format(
jtu.format_shape_dtype_string(shape, dtype),
preconditioner,
solve_method),
"shape": shape, "dtype": dtype, "preconditioner": preconditioner,
"solve_method": solve_method}
for shape in [(2, 2), (7, 7)]
for dtype in float_types + complex_types
for preconditioner in [None, 'identity', 'exact']
for solve_method in ['batched', 'incremental']
))
@jtu.sample_product(
shape=[(2, 2), (7, 7)],
dtype=float_types + complex_types,
preconditioner=[None, 'identity', 'exact'],
solve_method=['batched', 'incremental'],
)
@jtu.skip_on_devices("gpu")
def test_gmres_on_identity_system(self, shape, dtype, preconditioner,
solve_method):
@ -419,19 +383,12 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
solution_tol = 1e-8 if using_x64 else 1e-4
self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}_preconditioner={}_solve_method={}".format(
jtu.format_shape_dtype_string(shape, dtype),
preconditioner,
solve_method),
"shape": shape, "dtype": dtype, "preconditioner": preconditioner,
"solve_method": solve_method}
for shape in [(2, 2), (4, 4)]
for dtype in float_types + complex_types
for preconditioner in [None, 'identity', 'exact']
for solve_method in ['incremental', 'batched']
))
@jtu.sample_product(
shape=[(2, 2), (4, 4)],
dtype=float_types + complex_types,
preconditioner=[None, 'identity', 'exact'],
solve_method=['incremental', 'batched'],
)
@jtu.skip_on_devices("gpu")
def test_gmres_on_random_system(self, shape, dtype, preconditioner,
solve_method):
@ -469,15 +426,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
actual, _ = jax.scipy.sparse.linalg.gmres(A, b)
self.assertAllClose(expected, actual)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}_preconditioner={}".format(
jtu.format_shape_dtype_string(shape, dtype),
preconditioner),
"shape": shape, "dtype": dtype, "preconditioner": preconditioner}
for shape in [(2, 2), (3, 3)]
for dtype in float_types + complex_types
for preconditioner in [None, 'identity']))
@jtu.sample_product(
shape=[(2, 2), (3, 3)],
dtype=float_types + complex_types,
preconditioner=[None, 'identity'],
)
def test_gmres_arnoldi_step(self, shape, dtype, preconditioner):
"""
The Arnoldi decomposition within GMRES is correct.

View File

@ -141,7 +141,8 @@ JAX_SPECIAL_FUNCTION_RECORDS = [
# of inputs to some reasonable intervals
op_record("zeta", 2, float_dtypes, jtu.rand_positive, False),
# TODO: float64 produces aborts on gpu, potentially related to use of jnp.piecewise
op_record("expi", 1, [np.float32], jtu.rand_default, True),
op_record("expi", 1, [np.float32], partial(jtu.rand_not_small, offset=0.1),
True),
op_record("exp1", 1, [np.float32], jtu.rand_positive, True),
op_record("expn", 2, (int_dtypes, [np.float32]), jtu.rand_positive, True, (0,)),
]
@ -154,23 +155,19 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shapes={}_axis={}_keepdims={}_return_sign={}_use_b_{}".format(
jtu.format_shape_dtype_string(shapes, dtype),
axis, keepdims, return_sign, use_b),
# TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU.
"shapes": shapes, "dtype": dtype,
"axis": axis, "keepdims": keepdims,
"return_sign": return_sign, "use_b": use_b}
for shape_group in compatible_shapes for dtype in float_dtypes + complex_dtypes + int_dtypes
@jtu.sample_product(
[dict(shapes=shapes, axis=axis, use_b=use_b)
for shape_group in compatible_shapes
for use_b in [False, True]
for shapes in itertools.product(*(
(shape_group, shape_group) if use_b else (shape_group,)))
for axis in range(-max(len(shape) for shape in shapes),
max(len(shape) for shape in shapes))
for keepdims in [False, True]
for return_sign in [False, True]))
],
dtype=float_dtypes + complex_dtypes + int_dtypes,
keepdims=[False, True],
return_sign=[False, True],
)
@jtu.ignore_warning(category=RuntimeWarning, message="invalid value encountered in .*")
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testLogSumExp(self, shapes, dtype, axis,
@ -235,23 +232,23 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
result = lsp_special.logsumexp(1.0, b=1.0)
self.assertEqual(result, 1.0)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(
rec.test_name, shapes, dtypes),
"rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes,
"test_autodiff": rec.test_autodiff,
"nondiff_argnums": rec.nondiff_argnums,
"scipy_op": getattr(osp_special, rec.name),
"lax_op": getattr(lsp_special, rec.name)}
for shapes in itertools.combinations_with_replacement(all_shapes, rec.nargs)
for dtypes in (itertools.combinations_with_replacement(rec.dtypes, rec.nargs)
if isinstance(rec.dtypes, list) else itertools.product(*rec.dtypes)))
for rec in JAX_SPECIAL_FUNCTION_RECORDS))
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(op=rec.name, rng_factory=rec.rng_factory,
test_autodiff=rec.test_autodiff,
nondiff_argnums=rec.nondiff_argnums)],
shapes=itertools.combinations_with_replacement(all_shapes, rec.nargs),
dtypes=(itertools.combinations_with_replacement(rec.dtypes, rec.nargs)
if isinstance(rec.dtypes, list) else itertools.product(*rec.dtypes)),
)
for rec in JAX_SPECIAL_FUNCTION_RECORDS
))
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes, dtypes,
def testScipySpecialFun(self, op, rng_factory, shapes, dtypes,
test_autodiff, nondiff_argnums):
scipy_op = getattr(osp_special, op)
lax_op = getattr(lsp_special, op)
rng = rng_factory(self.rng())
args_maker = self._GetArgsMaker(rng, shapes, dtypes)
args = args_maker()
@ -272,13 +269,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
atol=jtu.if_device_under_test("tpu", .1, 1e-3),
rtol=.1, eps=1e-3)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_d={}".format(
jtu.format_shape_dtype_string(shape, dtype), d),
"shape": shape, "dtype": dtype, "d": d}
for shape in all_shapes
for dtype in float_dtypes
for d in [1, 2, 5]))
@jtu.sample_product(
shape=all_shapes,
dtype=float_dtypes,
d=[1, 2, 5],
)
@jax.numpy_rank_promotion('raise')
def testMultigammaln(self, shape, dtype, d):
def scipy_fun(a):
@ -322,13 +317,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.)
self.assertAllClose(jax.grad(partial_xlog1py)(-1.), 0., check_dtypes=False)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_lmax={}".format(
jtu.format_shape_dtype_string(shape, dtype), l_max),
"l_max": l_max, "shape": shape, "dtype": dtype}
for l_max in [1, 2, 3, 6]
for shape in [(5,), (10,)]
for dtype in float_dtypes))
@jtu.sample_product(
l_max=[1, 2, 3, 6],
shape=[(5,), (10,)],
dtype=float_dtypes,
)
def testLpmn(self, l_max, shape, dtype):
rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9)
args_maker = lambda: [rng(shape, dtype)]
@ -344,13 +337,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
atol=3e-3, check_dtypes=False)
self._CompileAndCheck(lax_fun, args_maker, rtol=1E-5, atol=3e-3)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_lmax={}".format(
jtu.format_shape_dtype_string(shape, dtype), l_max),
"l_max": l_max, "shape": shape, "dtype": dtype}
for l_max in [3, 4, 6, 32]
for shape in [(2,), (3,), (4,), (64,)]
for dtype in float_dtypes))
@jtu.sample_product(
l_max=[3, 4, 6, 32],
shape=[(2,), (3,), (4,), (64,)],
dtype=float_dtypes,
)
def testNormalizedLpmnValues(self, l_max, shape, dtype):
rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9)
args_maker = lambda: [rng(shape, dtype)]
@ -432,11 +423,12 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertAllClose(actual, expected, rtol=1e-8, atol=6e-8)
@parameterized.named_parameters(jtu.cases_from_list(
{'testcase_name': f'_maxdegree={l_max}_inputsize={num_z}_dtype={dtype.__name__}',
'l_max': l_max, 'num_z': num_z, 'dtype': dtype}
@jtu.sample_product(
[dict(l_max=l_max, num_z=num_z)
for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8])
for dtype in jtu.dtypes.all_integer))
],
dtype=jtu.dtypes.all_integer,
)
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype):
"""Tests against JIT compatibility and Numpy."""
@ -475,32 +467,18 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)
@parameterized.named_parameters(jtu.cases_from_list(
{'testcase_name':
'_shape={}'
'_n_zero_sv={}_degeneracy={}_geometric_spectrum={}'
'_max_sv={}_method={}_side={}'
'_nonzero_condition_number={}_seed={}'.format(
jtu.format_shape_dtype_string(
shape, jnp.dtype(dtype).name).replace(" ", ""),
n_zero_sv, degeneracy, geometric_spectrum, max_sv,
method, side, nonzero_condition_number, seed
),
'n_zero_sv': n_zero_sv, 'degeneracy': degeneracy,
'geometric_spectrum': geometric_spectrum,
'max_sv': max_sv, 'shape': shape, 'method': method,
'side': side, 'nonzero_condition_number': nonzero_condition_number,
'dtype': dtype, 'seed': seed}
for n_zero_sv in n_zero_svs
for degeneracy in degeneracies
for geometric_spectrum in geometric_spectra
for max_sv in max_svs
for shape in polar_shapes
for method in methods
for side in sides
for nonzero_condition_number in nonzero_condition_numbers
for dtype in jtu.dtypes.inexact
for seed in seeds))
@jtu.sample_product(
n_zero_sv=n_zero_svs,
degeneracy=degeneracies,
geometric_spectrum=geometric_spectra,
max_sv=max_svs,
shape=polar_shapes,
method=methods,
side=sides,
nonzero_condition_number=nonzero_condition_numbers,
dtype=jtu.dtypes.inexact,
seed=seeds,
)
def testPolar(
self, n_zero_sv, degeneracy, geometric_spectrum, max_sv, shape, method,
side, nonzero_condition_number, dtype, seed):
@ -552,15 +530,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertAllClose(
matrix, recon, atol=tol * jnp.linalg.norm(matrix))
@parameterized.named_parameters(jtu.cases_from_list(
{'testcase_name':
'_linear_size={}_dtype={}_termination_size={}'.format(
linear_size, jnp.dtype(dtype).name, termination_size),
'linear_size': linear_size, 'dtype': dtype,
'termination_size': termination_size}
for linear_size in linear_sizes
for dtype in jtu.dtypes.floating + jtu.dtypes.complex
for termination_size in [1, 19]))
@jtu.sample_product(
linear_size=linear_sizes,
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
termination_size=[1, 19],
)
def test_spectral_dac_eigh(self, linear_size, dtype, termination_size):
if jtu.device_under_test() != "tpu" and termination_size != 1:
raise unittest.SkipTest(
@ -584,13 +558,12 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
HV, vV, atol=atol * (80 if jnp.issubdtype(dtype, jnp.complexfloating)
else 30))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_{jtu.format_shape_dtype_string((n_obs, n_codes, *n_feats), dtype)}",
"n_obs": n_obs, "n_codes": n_codes, "n_feats": n_feats, "dtype": dtype}
for n_obs in [1, 3, 5]
for n_codes in [1, 2, 4]
for n_feats in [()] + [(i,) for i in range(1, 3)]
for dtype in float_dtypes + int_dtypes)) # scipy doesn't support complex
@jtu.sample_product(
n_obs=[1, 3, 5],
n_codes=[1, 2, 4],
n_feats=[()] + [(i,) for i in range(1, 3)],
dtype=float_dtypes + int_dtypes, # scipy doesn't support complex
)
def test_vq(self, n_obs, n_codes, n_feats, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng((n_obs, *n_feats), dtype), rng((n_codes, *n_feats), dtype)]

View File

@ -16,7 +16,6 @@
from functools import partial
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import numpy.random as npr
@ -35,12 +34,7 @@ npr.seed(0)
class MultiBackendTest(jtu.JaxTestCase):
"""Tests jit targeting to different backends."""
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_backend={backend}",
"backend": backend,
}
for backend in ['cpu', 'gpu', 'tpu', None]
))
@jtu.sample_product(backend=['cpu', 'gpu', 'tpu', None])
def testMultiBackend(self, backend):
if backend not in ('cpu', jtu.device_under_test(), None):
raise SkipTest("Backend is not CPU or the device under test")
@ -56,10 +50,9 @@ class MultiBackendTest(jtu.JaxTestCase):
correct_platform = backend if backend else jtu.device_under_test()
self.assertEqual(z.device().platform, correct_platform)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_ordering={ordering}",
"ordering": ordering,}
for ordering in [('cpu', None), ('gpu', None), ('tpu', None), (None, None)]))
@jtu.sample_product(
ordering=[('cpu', None), ('gpu', None), ('tpu', None), (None, None)]
)
def testMultiBackendNestedJit(self, ordering):
outer, inner = ordering
if outer not in ('cpu', jtu.device_under_test(), None):
@ -81,14 +74,11 @@ class MultiBackendTest(jtu.JaxTestCase):
correct_platform = outer if outer else jtu.device_under_test()
self.assertEqual(z.device().platform, correct_platform)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_ordering={ordering}",
"ordering": ordering,}
for ordering in [
('cpu', 'gpu'), ('gpu', 'cpu'),
('cpu', 'tpu'), ('tpu', 'cpu'),
(None, 'cpu'), (None, 'gpu'), (None, 'tpu'),
]))
@jtu.sample_product(
ordering=[('cpu', 'gpu'), ('gpu', 'cpu'), ('cpu', 'tpu'), ('tpu', 'cpu'),
(None, 'cpu'), (None, 'gpu'), (None, 'tpu'),
],
)
def testMultiBackendNestedJitConflict(self, ordering):
outer, inner = ordering
if outer not in ('cpu', jtu.device_under_test(), None):
@ -111,11 +101,7 @@ class MultiBackendTest(jtu.JaxTestCase):
y = npr.uniform(size=(10, 10))
self.assertRaises(ValueError, lambda: fun(x, y))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_backend={backend}",
"backend": backend,}
for backend in ['cpu', 'gpu', 'tpu']
))
@jtu.sample_product(backend=['cpu', 'gpu', 'tpu'])
def testGpuMultiBackendOpByOpReturn(self, backend):
if backend not in ('cpu', jtu.device_under_test()):
raise SkipTest("Backend is not CPU or the device under test")

View File

@ -262,9 +262,7 @@ class PrngTest(jtu.JaxTestCase):
expected64 = np.array([676898860, 3164047411, 4010691890], dtype=np.uint32)
self.assertArraysEqual(bits64, expected64)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_" + name, "prng_name": name}
for name, _ in PRNG_IMPLS))
@jtu.sample_product(prng_name=[name for name, _ in PRNG_IMPLS])
def testRngRandomBitsShapeDtype(self, prng_name):
# Like testRngRandomBits, but only meant to exercise random_bits
# on every PRNG implementation. Instead of values, only checks
@ -313,9 +311,7 @@ class PrngTest(jtu.JaxTestCase):
assert np.all(rand_bits_32 == rand_bits_32[0])
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": case._testname(), "case": case}
for case in _RANDOM_VALUES_CASES))
@jtu.sample_product(case=_RANDOM_VALUES_CASES)
@jtu.skip_on_devices("tpu") # TPU precision causes issues.
def testRandomDistributionValues(self, case):
"""
@ -366,8 +362,7 @@ class PrngTest(jtu.JaxTestCase):
_prng_key_as_array(random.fold_in(k, 4)),
np.array([2285895361, 433833334], dtype='uint32'))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "seed={seed}_type={type}_jit={jit}".format(**dct), **dct} for dct in [
@jtu.sample_product([
{"seed": 0, "type": int, "jit": True, "key": [0, 0]},
{"seed": 0, "type": int, "jit": False, "key": [0, 0]},
{"seed": 1, "type": np.int32, "jit": True, "key": [0, 1]},
@ -390,8 +385,7 @@ class PrngTest(jtu.JaxTestCase):
{"seed": np.iinfo(np.int32).min - 100, "type": int, "jit": False, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
{"seed": np.iinfo(np.int32).min - 101, "type": np.int64, "jit": True, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
{"seed": np.iinfo(np.int32).min - 101, "type": np.int64, "jit": False, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
]
))
])
def test_prng_seeds_and_keys(self, seed, type, jit, key):
if (jit and type is int and not config.x64_enabled and
(seed < np.iinfo('int32').min or seed > np.iinfo('int32').max)):
@ -525,9 +519,7 @@ class LaxRandomTest(jtu.JaxTestCase):
def seed_prng(self, seed):
return random.threefry2x32_key(seed)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype}
for dtype in jtu.dtypes.floating))
@jtu.sample_product(dtype=jtu.dtypes.floating)
def testNumpyAndXLAAgreeOnFloatEndianness(self, dtype):
bits_dtype = np.uint32 if jnp.finfo(dtype).bits == 32 else np.uint64
numpy_bits = np.array(1., dtype).view(bits_dtype)
@ -535,9 +527,7 @@ class LaxRandomTest(jtu.JaxTestCase):
lambda: lax.bitcast_convert_type(np.array(1., dtype), bits_dtype))()
self.assertEqual(numpy_bits, xla_bits)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype}
for dtype in float_dtypes))
@jtu.sample_product(dtype=float_dtypes)
def testRngUniform(self, dtype):
key = self.seed_prng(0)
rand = lambda key: random.uniform(key, (10000,), dtype)
@ -550,9 +540,7 @@ class LaxRandomTest(jtu.JaxTestCase):
self._CheckCollisions(samples, jnp.finfo(dtype).nmant)
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype}
for dtype in int_dtypes + uint_dtypes))
@jtu.sample_product(dtype=int_dtypes + uint_dtypes)
def testRngRandint(self, dtype):
lo = 5
hi = 10
@ -568,9 +556,7 @@ class LaxRandomTest(jtu.JaxTestCase):
self.assertTrue(np.all(lo <= samples))
self.assertTrue(np.all(samples < hi))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype}
for dtype in float_dtypes))
@jtu.sample_product(dtype=float_dtypes)
def testNormal(self, dtype):
key = self.seed_prng(0)
rand = lambda key: random.normal(key, (10000,), dtype)
@ -589,9 +575,7 @@ class LaxRandomTest(jtu.JaxTestCase):
res_bfloat16 = random.normal(self.seed_prng(0), dtype=jnp.bfloat16)
self.assertAllClose(res_bfloat16, res_bfloat16_str)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"dtype={np.dtype(dtype).name}", "dtype": dtype}
for dtype in complex_dtypes))
@jtu.sample_product(dtype=complex_dtypes)
def testNormalComplex(self, dtype):
key = self.seed_prng(0)
rand = lambda key: random.normal(key, (10000,), dtype)
@ -605,9 +589,7 @@ class LaxRandomTest(jtu.JaxTestCase):
self._CheckKolmogorovSmirnovCDF(jnp.imag(samples), scipy.stats.norm(scale=1/np.sqrt(2)).cdf)
self.assertEqual(dtype, samples.dtype)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype}
for dtype in float_dtypes))
@jtu.sample_product(dtype=float_dtypes)
def testTruncatedNormal(self, dtype):
key = self.seed_prng(0)
rand = lambda key: random.truncated_normal(key, -0.3, 0.3, (10000,), dtype)
@ -623,9 +605,7 @@ class LaxRandomTest(jtu.JaxTestCase):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.truncnorm(-0.3, 0.3).cdf)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype}
for dtype in jtu.dtypes.floating + jtu.dtypes.integer))
@jtu.sample_product(dtype=jtu.dtypes.floating + jtu.dtypes.integer)
def testShuffle(self, dtype):
key = self.seed_prng(0)
x = np.arange(100).astype(dtype)
@ -641,23 +621,21 @@ class LaxRandomTest(jtu.JaxTestCase):
self.assertFalse(np.all(perm1 == x)) # seems unlikely!
self.assertAllClose(np.sort(perm1), x, check_dtypes=False)
@parameterized.named_parameters(jtu.cases_from_list(
dict(
testcase_name=
f"_{np.dtype(dtype).name}_input_range_or_shape={input_range_or_shape}"
f"_shape={shape}_replace={replace}_weighted={weighted}_axis={axis}",
dtype=dtype, input_range_or_shape=input_range_or_shape,
shape=shape, replace=replace, weighted=weighted, axis=axis)
for dtype in jtu.dtypes.floating + jtu.dtypes.integer
for shape in [(), (5,), (4, 5)]
for replace in [True, False]
for weighted in [True, False]
for input_range_or_shape in [100, (10, 10), (10, 5, 2), 1, (1, 5)]
for is_range in [type(input_range_or_shape) is int]
for ndim in [1 if is_range else len(input_range_or_shape)]
for axis in range(-ndim, ndim or 1)
for ninputs in [input_range_or_shape if is_range else input_range_or_shape[axis]]
if replace or np.prod(shape) <= ninputs))
@jtu.sample_product(
[dict(shape=shape, replace=replace, axis=axis,
input_range_or_shape=input_range_or_shape)
for shape in [(), (5,), (4, 5)]
for replace in [True, False]
for input_range_or_shape in [100, (10, 10), (10, 5, 2), 1, (1, 5)]
for is_range in [type(input_range_or_shape) is int]
for ndim in [1 if is_range else len(input_range_or_shape)]
for axis in range(-ndim, ndim or 1)
for ninputs in [input_range_or_shape if is_range else input_range_or_shape[axis]]
if replace or np.prod(shape) <= ninputs
],
dtype=jtu.dtypes.floating + jtu.dtypes.integer,
weighted=[True, False],
)
def testChoice(self, dtype, input_range_or_shape, shape, replace, weighted, axis):
# This is the function API that we test against (note that self.rng().choice differs)
np_choice = np.random.default_rng(0).choice
@ -690,17 +668,16 @@ class LaxRandomTest(jtu.JaxTestCase):
self.assertArraysEqual(sample, jax.jit(rand, static_argnames=
'x' if is_range else None)(key, x))
@parameterized.named_parameters(jtu.cases_from_list(
dict(
testcase_name=f"_dtype={dtype}_range_or_shape={range_or_shape}"
f"_axis={axis}_independent={independent}",
dtype=dtype, range_or_shape=range_or_shape, axis=axis, independent=independent)
for dtype in jtu.dtypes.floating + jtu.dtypes.integer
for range_or_shape in [0, 1, 100, (0,), (1,), (100,),
(10, 10), (10, 5, 2), (0, 5), (1, 5)]
for ndim in [1 if type(range_or_shape) is int else len(range_or_shape)]
for axis in range(-ndim, ndim or 1)
for independent in [True, False]))
@jtu.sample_product(
[dict(range_or_shape=range_or_shape, axis=axis)
for range_or_shape in [0, 1, 100, (0,), (1,), (100,),
(10, 10), (10, 5, 2), (0, 5), (1, 5)]
for ndim in [1 if type(range_or_shape) is int else len(range_or_shape)]
for axis in range(-ndim, ndim or 1)
],
dtype=jtu.dtypes.floating + jtu.dtypes.integer,
independent=[True, False],
)
def testPermutation(self, dtype, range_or_shape, axis, independent):
key = self.seed_prng(0)
is_range = type(range_or_shape) is int
@ -739,11 +716,10 @@ class LaxRandomTest(jtu.JaxTestCase):
with self.assertRaises(core.ConcretizationTypeError):
jax.jit(random.permutation)(key, 10)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_p={p}_dtype={np.dtype(dtype).name}",
"p": p, "dtype": dtype}
for p in [0.1, 0.5, 0.9]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
p=[0.1, 0.5, 0.9],
dtype=jtu.dtypes.floating,
)
def testBernoulli(self, p, dtype):
key = self.seed_prng(0)
p = np.array(p, dtype=dtype)
@ -756,17 +732,18 @@ class LaxRandomTest(jtu.JaxTestCase):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckChiSquared(samples, scipy.stats.bernoulli(p).pmf)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_p={p}_{np.dtype(dtype).name}_{sample_shape}",
"p": p, "axis": axis, "dtype": dtype, 'sample_shape': sample_shape}
for (p, axis) in [
@jtu.sample_product(
[dict(p=p, axis=axis)
for (p, axis) in [
([.25] * 4, -1),
([.1, .2, .3, .4], -1),
([[.5, .5], [.1, .9]], 1),
([[.5, .1], [.5, .9]], 0),
]
for sample_shape in [(10000,), (5000, 2)]
for dtype in jtu.dtypes.floating))
]
],
sample_shape=[(10000,), (5000, 2)],
dtype=jtu.dtypes.floating,
)
def testCategorical(self, p, axis, dtype, sample_shape):
key = self.seed_prng(0)
p = np.array(p, dtype=dtype)
@ -800,12 +777,11 @@ class LaxRandomTest(jtu.JaxTestCase):
x = random.bernoulli(key, np.array([0.2, 0.3]), shape=(3, 2))
assert x.shape == (3, 2)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_a={a}_b={b}_dtype={np.dtype(dtype).name}",
"a": a, "b": b, "dtype": dtype}
for a in [0.2, 5.]
for b in [0.2, 5.]
for dtype in [np.float64])) # NOTE: KS test fails with float32
@jtu.sample_product(
a=[0.2, 5.],
b=[0.2, 5.],
dtype=[np.float64], # NOTE: KS test fails with float32
)
def testBeta(self, a, b, dtype):
if not config.x64_enabled:
raise SkipTest("skip test except on X64")
@ -832,9 +808,7 @@ class LaxRandomTest(jtu.JaxTestCase):
ones = samples[samples >= 0.5]
self.assertAllClose(ones, jnp.ones_like(ones))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype}
for dtype in float_dtypes))
@jtu.sample_product(dtype=float_dtypes)
def testCauchy(self, dtype):
key = self.seed_prng(0)
rand = lambda key: random.cauchy(key, (10000,), dtype)
@ -846,13 +820,10 @@ class LaxRandomTest(jtu.JaxTestCase):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.cauchy().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_alpha={alpha}_dtype={np.dtype(dtype).name}",
"alpha": alpha, "dtype": dtype}
for alpha in [
np.array([0.2, 1., 5.]),
]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
alpha=[np.array([0.2, 1., 5.]),],
dtype=jtu.dtypes.floating,
)
@jtu.skip_on_devices("tpu") # TODO(mattjj): slow compilation times
def testDirichlet(self, alpha, dtype):
key = self.seed_prng(0)
@ -883,9 +854,7 @@ class LaxRandomTest(jtu.JaxTestCase):
self.assertAllClose(samples.max(1), jnp.ones(samples.shape[0]),
check_dtypes=False, rtol=1E-5)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype}
for dtype in float_dtypes))
@jtu.sample_product(dtype=float_dtypes)
def testExponential(self, dtype):
key = self.seed_prng(0)
rand = lambda key: random.exponential(key, (10000,), dtype)
@ -897,11 +866,10 @@ class LaxRandomTest(jtu.JaxTestCase):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.expon().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_a={a}_dtype={np.dtype(dtype).name}",
"a": a, "dtype": dtype}
for a in [0.1, 1., 10.]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
a=[0.1, 1., 10.],
dtype=jtu.dtypes.floating,
)
def testGammaVsLogGamma(self, a, dtype):
key = self.seed_prng(0)
rand_gamma = lambda key, a: random.gamma(key, a, (10000,), dtype)
@ -911,11 +879,10 @@ class LaxRandomTest(jtu.JaxTestCase):
self.assertAllClose(rand_gamma(key, a), jnp.exp(rand_loggamma(key, a)))
self.assertAllClose(rand_gamma(key, a), jnp.exp(crand_loggamma(key, a)))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_a={a}_dtype={np.dtype(dtype).name}",
"a": a, "dtype": dtype}
for a in [0.1, 1., 10.]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
a=[0.1, 1., 10.],
dtype=jtu.dtypes.floating,
)
def testGamma(self, a, dtype):
key = self.seed_prng(0)
rand = lambda key, a: random.gamma(key, a, (10000,), dtype)
@ -932,11 +899,10 @@ class LaxRandomTest(jtu.JaxTestCase):
x = random.gamma(key, np.array([0.2, 0.3]), shape=(3, 2))
assert x.shape == (3, 2)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_a={alpha}_logspace={log_space}",
"alpha": alpha, "log_space": log_space}
for log_space in [True, False]
for alpha in [1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]))
@jtu.sample_product(
log_space=[True, False],
alpha=[1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4],
)
def testGammaGrad(self, log_space, alpha):
rng = self.seed_prng(0)
alphas = np.full((100,), alpha)
@ -969,11 +935,10 @@ class LaxRandomTest(jtu.JaxTestCase):
# Should not crash with a type error.
jax.vjp(f, a, b)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_lam={lam}_dtype={np.dtype(dtype).name}",
"lam": lam, "dtype": np.dtype(dtype)}
for lam in [0.5, 3, 9, 11, 50, 500]
for dtype in [np.int16, np.int32, np.int64]))
@jtu.sample_product(
lam=[0.5, 3, 9, 11, 50, 500],
dtype=[np.int16, np.int32, np.int64],
)
def testPoisson(self, lam, dtype):
key = self.seed_prng(0)
rand = lambda key, lam: random.poisson(key, lam, (10000,), dtype)
@ -1019,9 +984,7 @@ class LaxRandomTest(jtu.JaxTestCase):
samples = random.poisson(key, lam, shape=(3,))
self.assertArraysEqual(samples, jnp.array([-1, 0, -1]))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype}
for dtype in jtu.dtypes.floating))
@jtu.sample_product(dtype=jtu.dtypes.floating)
def testGumbel(self, dtype):
key = self.seed_prng(0)
rand = lambda key: random.gumbel(key, (10000,), dtype)
@ -1033,9 +996,7 @@ class LaxRandomTest(jtu.JaxTestCase):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gumbel_r().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype}
for dtype in float_dtypes))
@jtu.sample_product(dtype=float_dtypes)
def testLaplace(self, dtype):
key = self.seed_prng(0)
rand = lambda key: random.laplace(key, (10000,), dtype)
@ -1047,9 +1008,7 @@ class LaxRandomTest(jtu.JaxTestCase):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.laplace().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype}
for dtype in float_dtypes))
@jtu.sample_product(dtype=float_dtypes)
def testLogistic(self, dtype):
key = self.seed_prng(0)
rand = lambda key: random.logistic(key, (10000,), dtype)
@ -1061,15 +1020,11 @@ class LaxRandomTest(jtu.JaxTestCase):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.logistic().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_n={}_shape={}"\
.format(n, jtu.format_shape_dtype_string(shape, dtype)),
"n": n,
"shape": shape,
"dtype": dtype}
for n in range(1, 5)
for shape in [(), (5,), (10, 5)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
@jtu.sample_product(
n=range(1, 5),
shape=[(), (5,), (10, 5)],
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
)
def testOrthogonal(self, n, shape, dtype):
key = self.seed_prng(0)
q = random.orthogonal(key, n, shape, dtype)
@ -1083,15 +1038,11 @@ class LaxRandomTest(jtu.JaxTestCase):
atol=tol, rtol=tol,
)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_p={}_shape={}"\
.format(p, jtu.format_shape_dtype_string(shape, dtype)),
"p": p,
"shape": shape,
"dtype": dtype}
for p in [.5, 1., 1.5, 2., 2.5]
for shape in [(), (5,), (10, 5)]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
p=[.5, 1., 1.5, 2., 2.5],
shape=[(), (5,), (10, 5)],
dtype=jtu.dtypes.floating,
)
def testGeneralizedNormal(self, p, shape, dtype):
key = self.seed_prng(0)
rand = lambda key, p: random.generalized_normal(key, p, shape, dtype)
@ -1103,17 +1054,12 @@ class LaxRandomTest(jtu.JaxTestCase):
self.assertEqual(samples.dtype, dtype)
self._CheckKolmogorovSmirnovCDF(samples.ravel(), scipy.stats.gennorm(p).cdf)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_d={}_p={}_shape={}"\
.format(d, p, jtu.format_shape_dtype_string(shape, dtype)),
"d": d,
"p": p,
"shape": shape,
"dtype": dtype}
for d in range(1, 5)
for p in [.5, 1., 1.5, 2., 2.5]
for shape in [(), (5,), (10, 5)]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
d=range(1, 5),
p=[.5, 1., 1.5, 2., 2.5],
shape=[(), (5,), (10, 5)],
dtype=jtu.dtypes.floating,
)
def testBall(self, d, p, shape, dtype):
key = self.seed_prng(0)
rand = lambda key, p: random.ball(key, d, p, shape, dtype)
@ -1127,11 +1073,10 @@ class LaxRandomTest(jtu.JaxTestCase):
norms = (jnp.abs(samples) ** p).sum(-1) ** (d / p)
self._CheckKolmogorovSmirnovCDF(norms.ravel(), scipy.stats.uniform().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_b={b}_dtype={np.dtype(dtype).name}",
"b": b, "dtype": dtype}
for b in [0.1, 1., 10.]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
b=[0.1, 1., 10.],
dtype=jtu.dtypes.floating,
)
def testPareto(self, b, dtype):
key = self.seed_prng(0)
rand = lambda key, b: random.pareto(key, b, (10000,), dtype)
@ -1149,11 +1094,10 @@ class LaxRandomTest(jtu.JaxTestCase):
x = random.pareto(key, np.array([0.2, 0.3]), shape=(3, 2))
assert x.shape == (3, 2)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_df={df}_dtype={np.dtype(dtype).name}",
"df": df, "dtype": dtype}
for df in [0.1, 1., 10.]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
df=[0.1, 1., 10.],
dtype=jtu.dtypes.floating,
)
@jtu.skip_on_devices("cpu", "tpu") # TODO(phawkins): slow compilation times
def testT(self, df, dtype):
key = self.seed_prng(0)
@ -1166,13 +1110,11 @@ class LaxRandomTest(jtu.JaxTestCase):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.t(df).cdf)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dim={}_dtype={}_method={}".format(
dim, np.dtype(dtype), method),
"dim": dim, "dtype": dtype, "method": method}
for dim in [1, 3, 5]
for dtype in float_dtypes
for method in ['svd', 'eigh', 'cholesky']))
@jtu.sample_product(
dim=[1, 3, 5],
dtype=float_dtypes,
method=['svd', 'eigh', 'cholesky'],
)
def testMultivariateNormal(self, dim, dtype, method):
r = self.rng()
mean = r.randn(dim)
@ -1198,18 +1140,13 @@ class LaxRandomTest(jtu.JaxTestCase):
# eigenvectors follow a standard normal distribution.
self._CheckKolmogorovSmirnovCDF(whitened.ravel(), scipy.stats.norm().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dim={}_mean_batch_size={}_cov_batch_size={}_shape={}_method={}"\
.format(dim, mean_batch_size, cov_batch_size, shape, method),
"dim": dim,
"mean_batch_size": mean_batch_size,
"cov_batch_size": cov_batch_size,
"shape": shape, "method": method}
for dim in [1, 2, 4]
for mean_batch_size in [(), (3,), (2, 3)]
for cov_batch_size in [(), (3,), (2, 3)]
for shape in [(), (1,), (5,)]
for method in ['cholesky', 'svd', 'eigh']))
@jtu.sample_product(
dim=[1, 2, 4],
mean_batch_size=[(), (3,), (2, 3)],
cov_batch_size=[(), (3,), (2, 3)],
shape=[(), (1,), (5,)],
method=['cholesky', 'svd', 'eigh'],
)
def testMultivariateNormalShapes(self, dim, mean_batch_size, cov_batch_size,
shape, method):
r = self.rng()
@ -1426,10 +1363,10 @@ class LaxRandomTest(jtu.JaxTestCase):
with jax.enable_checks(False): # check_jaxpr will materialize array
jax.eval_shape(f, 0) # doesn't error
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_seed={seed}_type={type_}", "seed": seed, "type_": type_}
for type_ in ["int", "np.array", "jnp.array"]
for seed in [-1, 0, 1, (1 << 32) - 1, (1 << 63) - 1, np.uint64((1 << 64) - 1)]))
@jtu.sample_product(
type_=["int", "np.array", "jnp.array"],
seed=[-1, 0, 1, (1 << 32) - 1, (1 << 63) - 1, np.uint64((1 << 64) - 1)],
)
def test_prng_jit_invariance(self, seed, type_):
if type_ == "int" and seed == (1 << 64) - 1:
self.skipTest("Expected failure: Python int too large.")
@ -1453,9 +1390,7 @@ class LaxRandomTest(jtu.JaxTestCase):
jax.jit(random.split)(key)
self.assertLessEqual(count[0], 1) # 1 for the argument device_put
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={dtype}", "dtype": dtype}
for dtype in int_dtypes + uint_dtypes))
@jtu.sample_product(dtype=int_dtypes + uint_dtypes)
def test_randint_bounds(self, dtype):
min = np.iinfo(dtype).min
max = np.iinfo(dtype).max
@ -1541,10 +1476,7 @@ class KeyArrayTest(jtu.JaxTestCase):
self.assertIsInstance(keys, random.KeyArray)
self.assertEqual(keys.shape, (3,))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_internal" if use_internal else "",
"use_internal": use_internal}
for use_internal in [False, True]))
@jtu.sample_product(use_internal=[False, True])
def test_random_unwrap(self, use_internal):
unwrap = prng_internal.random_unwrap if use_internal else random.key_data
def f(k): return unwrap(k)

View File

@ -16,7 +16,7 @@
from functools import partial
import unittest
from absl.testing import absltest, parameterized
from absl.testing import absltest
import numpy as np
import scipy.signal as osp_signal
@ -70,22 +70,19 @@ def _complex_dtype(dtype):
class LaxBackedScipySignalTests(jtu.JaxTestCase):
"""Tests for LAX-backed scipy.stats implementations"""
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_xshape={}_yshape={}_mode={}".format(
op,
jtu.format_shape_dtype_string(xshape, dtype),
jtu.format_shape_dtype_string(yshape, dtype),
mode),
"xshape": xshape, "yshape": yshape, "dtype": dtype, "mode": mode,
"jsp_op": getattr(jsp_signal, op),
"osp_op": getattr(osp_signal, op)}
for mode in ['full', 'same', 'valid']
for op in ['convolve', 'correlate']
for dtype in default_dtypes
for shapeset in [onedim_shapes, twodim_shapes, threedim_shapes]
for xshape in shapeset
for yshape in shapeset))
def testConvolutions(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
@jtu.sample_product(
[dict(xshape=xshape, yshape=yshape)
for shapeset in [onedim_shapes, twodim_shapes, threedim_shapes]
for xshape in shapeset
for yshape in shapeset
],
mode=['full', 'same', 'valid'],
op=['convolve', 'correlate'],
dtype=default_dtypes,
)
def testConvolutions(self, xshape, yshape, dtype, mode, op):
jsp_op = getattr(jsp_signal, op)
osp_op = getattr(osp_signal, op)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
osp_fun = partial(osp_op, mode=mode)
@ -94,21 +91,16 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False, tol=tol)
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "op={}_xshape={}_yshape={}_mode={}".format(
op,
jtu.format_shape_dtype_string(xshape, dtype),
jtu.format_shape_dtype_string(yshape, dtype),
mode),
"xshape": xshape, "yshape": yshape, "dtype": dtype, "mode": mode,
"jsp_op": getattr(jsp_signal, op),
"osp_op": getattr(osp_signal, op)}
for mode in ['full', 'same', 'valid']
for op in ['convolve2d', 'correlate2d']
for dtype in default_dtypes
for xshape in twodim_shapes
for yshape in twodim_shapes))
def testConvolutions2D(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
@jtu.sample_product(
mode=['full', 'same', 'valid'],
op=['convolve2d', 'correlate2d'],
dtype=default_dtypes,
xshape=twodim_shapes,
yshape=twodim_shapes,
)
def testConvolutions2D(self, xshape, yshape, dtype, mode, op):
jsp_op = getattr(jsp_signal, op)
osp_op = getattr(osp_signal, op)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
osp_fun = partial(osp_op, mode=mode)
@ -118,15 +110,13 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
tol=tol)
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_axis={}_type={}_bp={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis, type, bp),
"shape": shape, "dtype": dtype, "axis": axis, "type": type, "bp": bp}
for shape in [(5,), (4, 5), (3, 4, 5)]
for dtype in jtu.dtypes.floating + jtu.dtypes.integer
for axis in [0, -1]
for type in ['constant', 'linear']
for bp in [0, [0, 2]]))
@jtu.sample_product(
shape=[(5,), (4, 5), (3, 4, 5)],
dtype=jtu.dtypes.floating + jtu.dtypes.integer,
axis=[0, -1],
type=['constant', 'linear'],
bp=[0, [0, 2]],
)
def testDetrend(self, shape, dtype, axis, type, bp):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
@ -144,24 +134,19 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=tol)
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
f"_fs={fs}_window={window}_boundary={boundary}_detrend={detrend}"
f"_padded={padded}_nperseg={nperseg}_noverlap={noverlap}"
f"_axis={timeaxis}_nfft={nfft}",
"shape": shape, "dtype": dtype, "fs": fs, "window": window,
"nperseg": nperseg, "noverlap": noverlap, "nfft": nfft,
"detrend": detrend, "boundary": boundary, "padded": padded,
"timeaxis": timeaxis}
@jtu.sample_product(
[dict(shape=shape, nperseg=nperseg, noverlap=noverlap, timeaxis=timeaxis,
nfft=nfft)
for shape, nperseg, noverlap, timeaxis in stft_test_shapes
for dtype in default_dtypes
for fs in [1.0, 16000.0]
for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
for detrend in ['constant', 'linear', False]
for boundary in [None, 'even', 'odd', 'zeros']
for padded in [True, False]))
],
dtype=default_dtypes,
fs=[1.0, 16000.0],
window=['boxcar', 'triang', 'blackman', 'hamming', 'hann'],
detrend=['constant', 'linear', False],
boundary=[None, 'even', 'odd', 'zeros'],
padded=[True, False],
)
def testStftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
noverlap, nfft, detrend, boundary, padded,
timeaxis):
@ -193,26 +178,19 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
# Tests with `average == 'median'`` is excluded from `testCsd*`
# due to the issue:
# https://github.com/scipy/scipy/issues/15601
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
f"_xshape={jtu.format_shape_dtype_string(xshape, dtype)}"
f"_yshape={jtu.format_shape_dtype_string(yshape, dtype)}"
f"_average={average}_scaling={scaling}_nfft={nfft}"
f"_fs={fs}_window={window}_detrend={detrend}"
f"_nperseg={nperseg}_noverlap={noverlap}"
f"_axis={timeaxis}",
"xshape": xshape, "yshape": yshape, "dtype": dtype, "fs": fs,
"window": window, "nperseg": nperseg, "noverlap": noverlap,
"nfft": nfft, "detrend": detrend, "scaling": scaling,
"timeaxis": timeaxis, "average": average}
@jtu.sample_product(
[dict(xshape=xshape, yshape=yshape, nperseg=nperseg, noverlap=noverlap,
timeaxis=timeaxis, nfft=nfft)
for xshape, yshape, nperseg, noverlap, timeaxis in csd_test_shapes
for dtype in default_dtypes
for fs in [1.0, 16000.0]
for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
for detrend in ['constant', 'linear', False]
for scaling in ['density', 'spectrum']
for average in ['mean']))
],
dtype=default_dtypes,
fs=[1.0, 16000.0],
window=['boxcar', 'triang', 'blackman', 'hamming', 'hann'],
detrend=['constant', 'linear', False],
scaling=['density', 'spectrum'],
average=['mean'],
)
def testCsdAgainstNumpy(
self, *, xshape, yshape, dtype, fs, window, nperseg, noverlap, nfft,
detrend, scaling, timeaxis, average):
@ -242,25 +220,19 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
f"_average={average}_scaling={scaling}_nfft={nfft}"
f"_fs={fs}_window={window}_detrend={detrend}"
f"_nperseg={nperseg}_noverlap={noverlap}"
f"_axis={timeaxis}",
"shape": shape, "dtype": dtype, "fs": fs,
"window": window, "nperseg": nperseg, "noverlap": noverlap,
"nfft": nfft, "detrend": detrend, "scaling": scaling,
"timeaxis": timeaxis, "average": average}
for shape, unused_yshape, nperseg, noverlap, timeaxis in csd_test_shapes
for dtype in default_dtypes
for fs in [1.0, 16000.0]
for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
@jtu.sample_product(
[dict(shape=shape, nperseg=nperseg, noverlap=noverlap, timeaxis=timeaxis,
nfft=nfft)
for shape, _yshape, nperseg, noverlap, timeaxis in csd_test_shapes
for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
for detrend in ['constant', 'linear', False]
for scaling in ['density', 'spectrum']
for average in ['mean']))
],
dtype=default_dtypes,
fs=[1.0, 16000.0],
window=['boxcar', 'triang', 'blackman', 'hamming', 'hann'],
detrend=['constant', 'linear', False],
scaling=['density', 'spectrum'],
average=['mean'],
)
def testCsdWithSameParamAgainstNumpy(
self, *, shape, dtype, fs, window, nperseg, noverlap, nfft,
detrend, scaling, timeaxis, average):
@ -292,26 +264,20 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
f"_fs={fs}_window={window}"
f"_nperseg={nperseg}_noverlap={noverlap}_nfft={nfft}"
f"_detrend={detrend}_return_onesided={return_onesided}"
f"_scaling={scaling}_axis={timeaxis}_average={average}",
"shape": shape, "dtype": dtype, "fs": fs, "window": window,
"nperseg": nperseg, "noverlap": noverlap, "nfft": nfft,
"detrend": detrend, "return_onesided": return_onesided,
"scaling": scaling, "timeaxis": timeaxis, "average": average}
@jtu.sample_product(
[dict(shape=shape, nperseg=nperseg, noverlap=noverlap, timeaxis=timeaxis,
nfft=nfft)
for shape, nperseg, noverlap, timeaxis in welch_test_shapes
for dtype in default_dtypes
for fs in [1.0, 16000.0]
for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
for detrend in ['constant', 'linear', False]
for return_onesided in [True, False]
for scaling in ['density', 'spectrum']
for average in ['mean', 'median']))
],
dtype=default_dtypes,
fs=[1.0, 16000.0],
window=['boxcar', 'triang', 'blackman', 'hamming', 'hann'],
detrend=['constant', 'linear', False],
return_onesided=[True, False],
scaling=['density', 'spectrum'],
average=['mean', 'median'],
)
def testWelchAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
noverlap, nfft, detrend, return_onesided,
scaling, timeaxis, average):
@ -342,20 +308,14 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
f"_nperseg={nperseg}_noverlap={noverlap}"
f"_use_nperseg={use_nperseg}_use_overlap={use_noverlap}"
f"_axis={timeaxis}",
"shape": shape, "dtype": dtype,
"nperseg": nperseg, "noverlap": noverlap,
"use_nperseg": use_nperseg, "use_noverlap": use_noverlap,
"timeaxis": timeaxis}
@jtu.sample_product(
[dict(shape=shape, nperseg=nperseg, noverlap=noverlap, timeaxis=timeaxis)
for shape, nperseg, noverlap, timeaxis in welch_test_shapes
for use_nperseg in [False, True]
for use_noverlap in [False, True]
for dtype in jtu.dtypes.floating + jtu.dtypes.integer))
],
use_nperseg=[False, True],
use_noverlap=[False, True],
dtype=jtu.dtypes.floating + jtu.dtypes.integer,
)
def testWelchWithDefaultStepArgsAgainstNumpy(
self, *, shape, dtype, nperseg, noverlap, use_nperseg, use_noverlap,
timeaxis):
@ -386,23 +346,18 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
f"_fs={fs}_window={window}_boundary={boundary}"
f"_nperseg={nperseg}_noverlap={noverlap}_onesided={onesided}"
f"_timeaxis={timeaxis}_freqaxis{freqaxis}_nfft={nfft}",
"shape": shape, "dtype": dtype, "fs": fs, "window": window,
"nperseg": nperseg, "noverlap": noverlap, "nfft": nfft,
"onesided": onesided, "boundary": boundary,
"timeaxis": timeaxis, "freqaxis": freqaxis}
@jtu.sample_product(
[dict(shape=shape, nperseg=nperseg, noverlap=noverlap, timeaxis=timeaxis,
freqaxis=freqaxis, nfft=nfft)
for shape, nperseg, noverlap, timeaxis, freqaxis in istft_test_shapes
for dtype in default_dtypes
for fs in [1.0, 16000.0]
for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
for onesided in [False, True]
for boundary in [False, True]))
],
dtype=default_dtypes,
fs=[1.0, 16000.0],
window=['boxcar', 'triang', 'blackman', 'hamming', 'hann'],
onesided=[False, True],
boundary=[False, True],
)
def testIstftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
noverlap, nfft, onesided, boundary,
timeaxis, freqaxis):

View File

@ -16,7 +16,7 @@
from functools import partial
import itertools
from absl.testing import absltest, parameterized
from absl.testing import absltest
import numpy as np
import scipy.stats as osp_stats
@ -34,12 +34,10 @@ one_and_two_dim_shapes = [(4,), (3, 4), (3, 1), (1, 4)]
def genNamedParametersNArgs(n):
return parameterized.named_parameters(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix("", shapes, dtypes),
"shapes": shapes, "dtypes": dtypes}
for shapes in itertools.combinations_with_replacement(all_shapes, n)
for dtypes in itertools.combinations_with_replacement(jtu.dtypes.floating, n)))
return jtu.sample_product(
shapes=itertools.combinations_with_replacement(all_shapes, n),
dtypes=itertools.combinations_with_replacement(jtu.dtypes.floating, n),
)
# Allow implicit rank promotion in these tests, as virtually every test exercises it.
@ -179,14 +177,14 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
@parameterized.named_parameters(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix("", [x_shape, alpha_shape], dtypes),
"shapes": [x_shape, alpha_shape], "dtypes": dtypes}
@jtu.sample_product(
shapes=[
[x_shape, alpha_shape]
for x_shape in one_and_two_dim_shapes
for alpha_shape in [(x_shape[0],), (x_shape[0] + 1,)]
for dtypes in itertools.combinations_with_replacement(jtu.dtypes.floating, 2)
))
],
dtypes=itertools.combinations_with_replacement(jtu.dtypes.floating, 2),
)
def testDirichletLogPdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
@ -564,13 +562,12 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
lsp_stats.norm.cdf(np.full((4,), np.inf, np.float32)),
check_dtypes=False)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix("", [shape, shape], [x_dtype, p_dtype]),
"x_dtype": x_dtype,
"p_dtype": p_dtype,
"shape": shape}
for shape in [(2), (4,), (1, 5)]
for (x_dtype, p_dtype) in itertools.product(jtu.dtypes.integer, jtu.dtypes.floating)))
@jtu.sample_product(
[dict(x_dtype=x_dtype, p_dtype=p_dtype)
for x_dtype, p_dtype in itertools.product(jtu.dtypes.integer, jtu.dtypes.floating)
],
shape=[(2), (4,), (1, 5)],
)
def testMultinomialLogPmf(self, shape, x_dtype, p_dtype):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.multinomial.logpmf
@ -588,16 +585,8 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_x={}_mean={}_cov={}".format(
jtu.format_shape_dtype_string(x_shape, x_dtype),
jtu.format_shape_dtype_string(mean_shape, mean_dtype)
if mean_shape is not None else None,
jtu.format_shape_dtype_string(cov_shape, cov_dtype)
if cov_shape is not None else None),
"x_shape": x_shape, "x_dtype": x_dtype,
"mean_shape": mean_shape, "mean_dtype": mean_dtype,
"cov_shape": cov_shape, "cov_dtype": cov_dtype}
@jtu.sample_product(
[dict(x_shape=x_shape, mean_shape=mean_shape, cov_shape=cov_shape)
for x_shape, mean_shape, cov_shape in [
# # These test cases cover default values for mean/cov, but we don't
# # support those yet (and they seem not very valuable).
@ -616,9 +605,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
[(3, 4), (4,), (4, 4)],
[(2, 3, 4), (4,), (4, 4)],
]
for x_dtype, mean_dtype, cov_dtype in itertools.combinations_with_replacement(jtu.dtypes.floating, 3)
if (mean_shape is not None or mean_dtype == np.float32)
and (cov_shape is not None or cov_dtype == np.float32)))
],
[dict(x_dtype=x_dtype, mean_dtype=mean_dtype, cov_dtype=cov_dtype)
for x_dtype, mean_dtype, cov_dtype in itertools.combinations_with_replacement(jtu.dtypes.floating, 3)
],
# if (mean_shape is not None or mean_dtype == np.float32)
# and (cov_shape is not None or cov_dtype == np.float32)))
)
def testMultivariateNormalLogpdf(self, x_shape, x_dtype, mean_shape,
mean_dtype, cov_shape, cov_dtype):
rng = jtu.rand_default(self.rng())
@ -642,16 +635,8 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
rtol=1e-4, atol=1e-4)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_x={}_mean={}_cov={}".format(
jtu.format_shape_dtype_string(x_shape, x_dtype),
jtu.format_shape_dtype_string(mean_shape, mean_dtype)
if mean_shape is not None else None,
jtu.format_shape_dtype_string(cov_shape, cov_dtype)
if cov_shape is not None else None),
"x_shape": x_shape, "x_dtype": x_dtype,
"mean_shape": mean_shape, "mean_dtype": mean_dtype,
"cov_shape": cov_shape, "cov_dtype": cov_dtype}
@jtu.sample_product(
[dict(x_shape=x_shape, mean_shape=mean_shape, cov_shape=cov_shape)
for x_shape, mean_shape, cov_shape in [
# These test cases are where scipy flattens things, which has
# different batch semantics than some might expect, so we manually
@ -663,9 +648,11 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
[(1, 3, 2), (3, 2,), (5, 1, 2, 2)],
[(5, 3, 2), (1, 2,), (2, 2)],
]
for x_dtype, mean_dtype, cov_dtype in itertools.combinations_with_replacement(jtu.dtypes.floating, 3)
if (mean_shape is not None or mean_dtype == np.float32)
and (cov_shape is not None or cov_dtype == np.float32)))
],
[dict(x_dtype=x_dtype, mean_dtype=mean_dtype, cov_dtype=cov_dtype)
for x_dtype, mean_dtype, cov_dtype in itertools.combinations_with_replacement(jtu.dtypes.floating, 3)
],
)
def testMultivariateNormalLogpdfBroadcasted(self, x_shape, x_dtype, mean_shape,
mean_dtype, cov_shape, cov_dtype):
rng = jtu.rand_default(self.rng())
@ -691,12 +678,11 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
rtol=1e-4, atol=1e-4)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_ndim={ndim}_nbatch={nbatch}_dtype={dtype.__name__}",
"ndim": ndim, "nbatch": nbatch, "dtype": dtype}
for ndim in [2, 3]
for nbatch in [1, 3, 5]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
ndim=[2, 3],
nbatch=[1, 3, 5],
dtype=jtu.dtypes.floating,
)
def testMultivariateNormalLogpdfBatch(self, ndim, nbatch, dtype):
# Regression test for #5570
rng = jtu.rand_default(self.rng())
@ -709,23 +695,14 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
result2 = jax.vmap(lsp_stats.multivariate_normal.logpdf)(x, mean, cov)
self.assertArraysEqual(result1, result2, check_dtypes=False)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_inshape={}_outsize={}_weights={}_method={}_func={}".format(
jtu.format_shape_dtype_string(inshape, dtype),
outsize, weights, method, func),
"dtype": dtype,
"inshape": inshape,
"outsize": outsize,
"weights": weights,
"method": method,
"func": func}
for inshape in [(50,), (3, 50), (2, 12)]
for dtype in jtu.dtypes.floating
for outsize in [None, 10]
for weights in [False, True]
for method in [None, "scott", "silverman", 1.5, "callable"]
for func in [None, "evaluate", "logpdf", "pdf"]))
@jtu.sample_product(
inshape=[(50,), (3, 50), (2, 12)],
dtype=jtu.dtypes.floating,
outsize=[None, 10],
weights=[False, True],
method=[None, "scott", "silverman", 1.5, "callable"],
func=[None, "evaluate", "logpdf", "pdf"],
)
def testKde(self, inshape, dtype, outsize, weights, method, func):
if method == "callable":
method = lambda kde: jax.numpy.power(kde.neff, -1./(kde.d+4))
@ -764,12 +741,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CompileAndCheck(
lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]),
"dtype": dtype,
"shape": shape}
for shape in [(15,), (3, 15), (1, 12)]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
shape=[(15,), (3, 15), (1, 12)],
dtype=jtu.dtypes.floating,
)
def testKdeIntegrateGaussian(self, shape, dtype):
def scipy_fun(dataset, weights):
kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights))
@ -796,12 +771,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CompileAndCheck(
lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]),
"dtype": dtype,
"shape": shape}
for shape in [(15,), (12,)]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
shape=[(15,), (12,)],
dtype=jtu.dtypes.floating,
)
def testKdeIntegrateBox1d(self, shape, dtype):
def scipy_fun(dataset, weights):
kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights))
@ -820,12 +793,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CompileAndCheck(
lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]),
"dtype": dtype,
"shape": shape}
for shape in [(15,), (3, 15), (1, 12)]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
shape=[(15,), (3, 15), (1, 12)],
dtype=jtu.dtypes.floating,
)
def testKdeIntegrateKde(self, shape, dtype):
def scipy_fun(dataset, weights):
kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights))
@ -848,12 +819,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CompileAndCheck(
lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]),
"dtype": dtype,
"shape": shape}
for shape in [(15,), (3, 15), (1, 12)]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
shape=[(15,), (3, 15), (1, 12)],
dtype=jtu.dtypes.floating,
)
def testKdeResampleShape(self, shape, dtype):
def resample(key, dataset, weights, *, shape):
kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights))
@ -878,12 +847,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
result = func(*args)
assert result.shape == (ndim, 4)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]),
"dtype": dtype,
"shape": shape}
for shape in [(15,), (1, 12)]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
shape=[(15,), (1, 12)],
dtype=jtu.dtypes.floating,
)
def testKdeResample1d(self, shape, dtype):
rng = jtu.rand_default(self.rng())
dataset = rng(shape, dtype)

File diff suppressed because it is too large Load Diff

View File

@ -205,18 +205,15 @@ class SparsifyTest(jtu.JaxTestCase):
self.assertAllClose(out.todense(), x.todense() + y.todense())
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_nbatch={}_ndense={}_unique_indices={}".format(
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense,
unique_indices),
"shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense,
"unique_indices": unique_indices}
@jtu.sample_product(
[dict(shape=shape, n_batch=n_batch, n_dense=n_dense)
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
for dtype in (jtu.dtypes.integer + jtu.dtypes.floating +
jtu.dtypes.complex)
for n_batch in range(len(shape) + 1)
for n_dense in range(len(shape) + 1 - n_batch)
for unique_indices in [True, False]))
],
dtype=jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex,
unique_indices=[True, False],
)
def testSparseMul(self, shape, dtype, n_batch, n_dense, unique_indices):
rng_sparse = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero)
x = BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch,
@ -281,10 +278,8 @@ class SparsifyTest(jtu.JaxTestCase):
res_sparse = res_sparse.todense()
self.assertArraysAllClose(res_dense, res_sparse)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_dimensions={}_nbatch={}_ndense={}".format(
jtu.format_shape_dtype_string(shape, np.float32), dimensions, n_batch, n_dense),
"shape": shape, "dimensions": dimensions, "n_batch": n_batch, "n_dense": n_dense}
@jtu.sample_product(
[dict(shape=shape, dimensions=dimensions, n_batch=n_batch, n_dense=n_dense)
for shape, dimensions in [
[(1,), (0,)],
[(1,), (-1,)],
@ -294,7 +289,9 @@ class SparsifyTest(jtu.JaxTestCase):
[(2, 1, 3, 1), (3,)],
]
for n_batch in range(len(shape) + 1)
for n_dense in range(len(shape) - n_batch + 1)))
for n_dense in range(len(shape) - n_batch + 1)
],
)
def testSparseSqueeze(self, shape, dimensions, n_batch, n_dense):
rng = jtu.rand_default(self.rng())
@ -307,9 +304,8 @@ class SparsifyTest(jtu.JaxTestCase):
self.assertAllClose(result_sparse, result_dense)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_shapes={shapes}_func={func}_nbatch={n_batch}",
"shapes": shapes, "func": func, "n_batch": n_batch}
@jtu.sample_product(
[dict(shapes=shapes, func=func, n_batch=n_batch)
for shapes, func, n_batch in [
([(4,), (4,)], "concatenate", 0),
([(4,), (4,)], "stack", 0),
@ -328,7 +324,9 @@ class SparsifyTest(jtu.JaxTestCase):
([(2, 4), (2, 5)], "hstack", 2),
([(2, 4), (4,), (3, 4)], "vstack", 0),
([(1, 4), (4,), (1, 4)], "vstack", 0),
]))
]
],
)
def testSparseConcatenate(self, shapes, func, n_batch):
f = self.sparsify(getattr(jnp, func))
rng = jtu.rand_some_zero(self.rng())
@ -336,9 +334,8 @@ class SparsifyTest(jtu.JaxTestCase):
sparrs = [BCOO.fromdense(arr, n_batch=n_batch) for arr in arrs]
self.assertArraysEqual(f(arrs), f(sparrs).todense())
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_{shape}->{new_shape}_n_batch={n_batch}_n_dense={n_dense}",
"shape": shape, "new_shape": new_shape, "n_batch": n_batch, "n_dense": n_dense}
@jtu.sample_product(
[dict(shape=shape, new_shape=new_shape, n_batch=n_batch, n_dense=n_dense)
for shape, new_shape, n_batch, n_dense in [
[(6,), (2, 3), 0, 0],
[(1, 4), (2, 2), 0, 0],
@ -348,7 +345,9 @@ class SparsifyTest(jtu.JaxTestCase):
[(2, 3, 4), (3, 8), 0, 0],
[(2, 3, 4), (1, 2, 12), 1, 0],
[(2, 3, 4), (6, 2, 2), 2, 0],
]))
]
],
)
def testSparseReshapeMethod(self, shape, new_shape, n_batch, n_dense):
rng = jtu.rand_some_zero(self.rng())
arr = rng(shape, 'int32')
@ -359,10 +358,9 @@ class SparsifyTest(jtu.JaxTestCase):
self.assertArraysEqual(arr2, arr2_sparse.todense())
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_{shape}->{new_shape}_n_batch={n_batch}_n_dense={n_dense}_dimensions={dimensions}",
"shape": shape, "new_shape": new_shape, "n_batch": n_batch, "n_dense": n_dense,
"dimensions": dimensions}
@jtu.sample_product(
[dict(shape=shape, new_shape=new_shape, n_batch=n_batch, n_dense=n_dense,
dimensions=dimensions)
for shape, new_shape, n_batch, n_dense, dimensions in [
[(2, 3, 4), (24,), 0, 0, None],
[(2, 3, 4), (24,), 0, 0, (0, 1, 2)],
@ -375,7 +373,9 @@ class SparsifyTest(jtu.JaxTestCase):
[(4, 2, 3), (2, 2, 6), 1, 0, (0, 2, 1)],
[(2, 3, 4), (6, 4), 2, 0, (0, 1, 2)],
[(2, 3, 4), (6, 4), 2, 0, (1, 0, 2)],
]))
]
],
)
def testSparseReshapeWithDimensions(self, shape, new_shape, n_batch, n_dense, dimensions):
rng = jtu.rand_some_zero(self.rng())
arr = rng(shape, 'int32')

View File

@ -15,7 +15,6 @@
"""Tests for Stax library."""
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
@ -50,105 +49,81 @@ def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape):
@jtu.with_config(jax_numpy_rank_promotion="allow")
class StaxTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_shape={shape}", "shape": shape}
for shape in [(2, 3), (5,)]))
@jtu.sample_product(shape=[(2, 3), (5,)])
def testRandnInitShape(self, shape):
key = random.PRNGKey(0)
out = stax.randn()(key, shape)
self.assertEqual(out.shape, shape)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_shape={shape}", "shape": shape}
for shape in [(2, 3), (2, 3, 4)]))
@jtu.sample_product(shape=[(2, 3), (2, 3, 4)])
def testGlorotInitShape(self, shape):
key = random.PRNGKey(0)
out = stax.glorot()(key, shape)
self.assertEqual(out.shape, shape)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
.format(channels, filter_shape, padding, strides, input_shape),
"channels": channels, "filter_shape": filter_shape, "padding": padding,
"strides": strides, "input_shape": input_shape}
for channels in [2, 3]
for filter_shape in [(1, 1), (2, 3)]
for padding in ["SAME", "VALID"]
for strides in [None, (2, 1)]
for input_shape in [(2, 10, 11, 1)]))
@jtu.sample_product(
channels=[2, 3],
filter_shape=[(1, 1), (2, 3)],
padding=["SAME", "VALID"],
strides=[None, (2, 1)],
input_shape=[(2, 10, 11, 1)],
)
def testConvShape(self, channels, filter_shape, padding, strides,
input_shape):
init_fun, apply_fun = stax.Conv(channels, filter_shape, strides=strides,
padding=padding)
_CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
.format(channels, filter_shape, padding, strides, input_shape),
"channels": channels, "filter_shape": filter_shape, "padding": padding,
"strides": strides, "input_shape": input_shape}
for channels in [2, 3]
for filter_shape in [(1, 1), (2, 3), (3, 3)]
for padding in ["SAME", "VALID"]
for strides in [None, (2, 1), (2, 2)]
for input_shape in [(2, 10, 11, 1)]))
@jtu.sample_product(
channels=[2, 3],
filter_shape=[(1, 1), (2, 3), (3, 3)],
padding=["SAME", "VALID"],
strides=[None, (2, 1), (2, 2)],
input_shape=[(2, 10, 11, 1)],
)
def testConvTransposeShape(self, channels, filter_shape, padding, strides,
input_shape):
init_fun, apply_fun = stax.ConvTranspose(channels, filter_shape, # 2D
strides=strides, padding=padding)
_CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
.format(channels, filter_shape, padding, strides, input_shape),
"channels": channels, "filter_shape": filter_shape, "padding": padding,
"strides": strides, "input_shape": input_shape}
for channels in [2, 3]
for filter_shape in [(1,), (2,), (3,)]
for padding in ["SAME", "VALID"]
for strides in [None, (1,), (2,)]
for input_shape in [(2, 10, 1)]))
@jtu.sample_product(
channels=[2, 3],
filter_shape=[(1,), (2,), (3,)],
padding=["SAME", "VALID"],
strides=[None, (1,), (2,)],
input_shape=[(2, 10, 1)],
)
def testConv1DTransposeShape(self, channels, filter_shape, padding, strides,
input_shape):
init_fun, apply_fun = stax.Conv1DTranspose(channels, filter_shape,
strides=strides, padding=padding)
_CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_out_dim={}_input_shape={}"
.format(out_dim, input_shape),
"out_dim": out_dim, "input_shape": input_shape}
for out_dim in [3, 4]
for input_shape in [(2, 3), (3, 4)]))
@jtu.sample_product(
out_dim=[3, 4],
input_shape=[(2, 3), (3, 4)],
)
def testDenseShape(self, out_dim, input_shape):
init_fun, apply_fun = stax.Dense(out_dim)
_CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_input_shape={}_nonlinear={}"
.format(input_shape, nonlinear),
"input_shape": input_shape, "nonlinear": nonlinear}
for input_shape in [(2, 3), (2, 3, 4)]
for nonlinear in ["Relu", "Sigmoid", "Elu", "LeakyRelu"]))
@jtu.sample_product(
input_shape=[(2, 3), (2, 3, 4)],
nonlinear=["Relu", "Sigmoid", "Elu", "LeakyRelu"],
)
def testNonlinearShape(self, input_shape, nonlinear):
init_fun, apply_fun = getattr(stax, nonlinear)
_CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_window_shape={}_padding={}_strides={}_input_shape={}"
"_maxpool={}_spec={}"
.format(window_shape, padding, strides, input_shape,
max_pool, spec),
"window_shape": window_shape, "padding": padding, "strides": strides,
"input_shape": input_shape, "max_pool": max_pool, "spec": spec}
for window_shape in [(1, 1), (2, 3)]
for padding in ["VALID"]
for strides in [None, (2, 1)]
for input_shape in [(2, 5, 6, 4)]
for max_pool in [False, True]
for spec in ["NHWC", "NCHW", "WHNC", "WHCN"]))
@jtu.sample_product(
window_shape=[(1, 1), (2, 3)],
padding=["VALID"],
strides=[None, (2, 1)],
input_shape=[(2, 5, 6, 4)],
max_pool=[False, True],
spec=["NHWC", "NCHW", "WHNC", "WHCN"],
)
def testPoolingShape(self, window_shape, padding, strides, input_shape,
max_pool, spec):
layer = stax.MaxPool if max_pool else stax.AvgPool
@ -156,49 +131,41 @@ class StaxTest(jtu.JaxTestCase):
spec=spec)
_CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_shape={input_shape}",
"input_shape": input_shape}
for input_shape in [(2, 3), (2, 3, 4)]))
@jtu.sample_product(input_shape=[(2, 3), (2, 3, 4)])
def testFlattenShape(self, input_shape):
init_fun, apply_fun = stax.Flatten
_CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_input_shape={input_shape}_spec={i}",
"input_shape": input_shape, "spec": spec}
for input_shape in [(2, 5, 6, 1)]
for i, spec in enumerate([
[stax.Conv(3, (2, 2))],
[stax.Conv(3, (2, 2)), stax.Flatten, stax.Dense(4)]])))
@jtu.sample_product(
input_shape=[(2, 5, 6, 1)],
spec=[
[stax.Conv(3, (2, 2))],
[stax.Conv(3, (2, 2)), stax.Flatten, stax.Dense(4)],
],
)
def testSerialComposeLayersShape(self, input_shape, spec):
init_fun, apply_fun = stax.serial(*spec)
_CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_input_shape={input_shape}",
"input_shape": input_shape}
for input_shape in [(3, 4), (2, 5, 6, 1)]))
@jtu.sample_product(input_shape=[(3, 4), (2, 5, 6, 1)])
def testDropoutShape(self, input_shape):
init_fun, apply_fun = stax.Dropout(0.9)
_CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_input_shape={input_shape}",
"input_shape": input_shape}
for input_shape in [(3, 4), (2, 5, 6, 1)]))
@jtu.sample_product(input_shape=[(3, 4), (2, 5, 6, 1)])
def testFanInSum(self, input_shape):
init_fun, apply_fun = stax.FanInSum
_CheckShapeAgreement(self, init_fun, apply_fun, [input_shape, input_shape])
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_inshapes={input_shapes}_axis={axis}",
"input_shapes": input_shapes, "axis": axis}
@jtu.sample_product(
[dict(input_shapes=input_shapes, axis=axis)
for input_shapes, axis in [
([(2, 3), (2, 1)], 1),
([(2, 3), (2, 1)], -1),
([(1, 2, 4), (1, 1, 4)], 1),
]))
]
],
)
def testFanInConcat(self, input_shapes, axis):
init_fun, apply_fun = stax.FanInConcat(axis)
_CheckShapeAgreement(self, init_fun, apply_fun, input_shapes)

View File

@ -46,7 +46,7 @@ from jax._src.nn import initializers as nn_initializers
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.util import curry, unzip2, prod, safe_zip
from jax._src.util import unzip2, prod, safe_zip
from jax._src.lax import parallel as lax_parallel
from jax._src.lax.parallel import pgather
from jax.interpreters import batching, pxla
@ -876,13 +876,12 @@ class XMapTestManualSPMD(ManualSPMDTestMixin, XMapTestCase):
class NamedNumPyTest(XMapTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_{reduction.__name__}_axes={axes}_i={mapped_axis}",
"reduction": reduction, "axes": axes, "mapped_axis": mapped_axis}
for reduction in (jnp.sum, jnp.max, jnp.min, jnp.mean, jnp.var, jnp.std,
jscipy.special.logsumexp)
for axes in (0, 'i', (1,), ('i',), (0, 1), (0, 'i'), ('i', 0))
for mapped_axis in range(3)))
@jtu.sample_product(
reduction=(jnp.sum, jnp.max, jnp.min, jnp.mean, jnp.var, jnp.std,
jscipy.special.logsumexp),
axes=(0, 'i', (1,), ('i',), (0, 1), (0, 'i'), ('i', 0)),
mapped_axis=range(3),
)
def testReductions(self, reduction, axes, mapped_axis):
axes_t = axes if isinstance(axes, tuple) else (axes,)
ref_red = partial(reduction,
@ -901,25 +900,15 @@ class NamedNumPyTest(XMapTestCase):
class NamedRandomTest(XMapTestCase):
@curry
def parameterize_by_sampler(extra, f, subset):
if extra is None:
extra = [("", {})]
else:
extra = list(extra)
subset_fn = jtu.cases_from_list if subset else lambda x: x
return parameterized.named_parameters(subset_fn(
{"testcase_name": name + extra_name, "distr_sample": sample, **extra_kwargs}
for name, sample in [
("Uniform", jax.random.uniform),
("Normal", jax.random.normal),
("Bernoulli", partial(jax.random.bernoulli, p=0.5)),
("TruncatedNormal", partial(jax.random.truncated_normal, lower=-2, upper=2)),
]
for extra_name, extra_kwargs in extra))(f)
SAMPLERS = [
("Uniform", jax.random.uniform),
("Normal", jax.random.normal),
("Bernoulli", partial(jax.random.bernoulli, p=0.5)),
("TruncatedNormal", partial(jax.random.truncated_normal, lower=-2, upper=2)),
]
@parameterize_by_sampler(None, subset=False)
def testSamplerSharding(self, distr_sample):
@parameterized.parameters(*SAMPLERS)
def testSamplerSharding(self, distr_name, distr_sample):
def sample(shape, map_size):
return xmap(lambda: distr_sample(jax.random.PRNGKey(0), shape=shape),
in_axes=(), out_axes=[None, 'i', ...], axis_sizes={'i': map_size})()
@ -931,12 +920,15 @@ class NamedRandomTest(XMapTestCase):
with self.assertRaisesRegex(ValueError, error):
sample(NamedShape(3, i=4), 5)
@parameterize_by_sampler(
((f"_mesh={mesh}_resources={sorted(axis_resources.items())}",
{"axis_resources": tuple(axis_resources.items()), "mesh": tuple(mesh)})
for axis_resources, mesh in schedules({'i': 4, 'j': 6})), subset=True)
@jtu.sample_product(
[dict(distr_name=name, distr_sample=sample)
for name, sample in SAMPLERS],
[dict(axis_resources=tuple(axis_resources.items()), mesh=tuple(mesh))
for axis_resources, mesh in schedules({'i': 4, 'j': 6})],
)
@jtu.with_mesh_from_kwargs
def testSamplerResourceIndependence(self, distr_sample, axis_resources, mesh):
def testSamplerResourceIndependence(self, distr_name, distr_sample,
axis_resources, mesh):
def sample(axis_resources):
return xmap(lambda: distr_sample(jax.random.PRNGKey(0), shape=NamedShape(3, i=4, j=6)),
in_axes=(), out_axes=['i', 'j', ...], axis_sizes={'i': 4, 'j': 6},
@ -965,13 +957,12 @@ class NamedNNTest(XMapTestCase):
with self.assertRaisesRegex(ValueError, "to match the size of axis i, but 3 != 5"):
f(jnp.ones(5, dtype='int32'))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_map_in={map_in}_map_out={map_out}_fan={fan}_distr={distr}",
"map_in": map_in, "map_out": map_out, "fan": fan,
"distr": distr}
for map_in, map_out in [(True, False), (False, True), (True, True)]
for fan in ['fan_in', 'fan_out', 'fan_avg']
for distr in ['uniform', 'normal', 'truncated_normal']))
@jtu.sample_product(
[dict(map_in=map_in, map_out=map_out)
for map_in, map_out in [(True, False), (False, True), (True, True)]],
fan=['fan_in', 'fan_out', 'fan_avg'],
distr=['uniform', 'normal', 'truncated_normal'],
)
def testVarianceScaling(self, map_in, map_out, fan, distr):
shape = (80, 50, 7)
fan_in, fan_out = nn_initializers._compute_fans(NamedShape(*shape), 0, 1)