Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)

Migrate remaining tests from jtu.cases_from_list to jtu.sample_product.
Delete jtu.cases_from_list.

commit 72f4f389be (parent 9bb2c999d6)
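The pattern applied throughout this commit is illustrated below. This is a hedged sketch, not a hunk from the diff itself: the method names are placeholders, but the before/after decorator shapes follow the changes in the hunks that follow.

    # Before: test-case names built by hand and subsampled via
    # jtu.cases_from_list (removed by this commit).
    @parameterized.named_parameters(jtu.cases_from_list(
        dict(testcase_name=f"_dtype={dtype.__name__}_function={with_function}",
             dtype=dtype, with_function=with_function)
        for dtype in [np.int64, np.float64]
        for with_function in [True, False]))
    def test_example_old(self, dtype, with_function):
      ...

    # After: jtu.sample_product names the cases and subsamples the Cartesian
    # product of the keyword axes automatically.
    @jtu.sample_product(
        dtype=[np.int64, np.float64],
        with_function=[True, False],
    )
    def test_example_new(self, dtype, with_function):
      ...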
@@ -658,20 +658,6 @@ def assert_dot_precision(expected_precision, fun, *args):
  assert precision == expected_precision, msg


_CACHED_INDICES: Dict[int, Sequence[int]] = {}

def cases_from_list(xs):
  xs = list(xs)
  n = len(xs)
  k = min(n, FLAGS.num_generated_cases)
  # Random sampling for every parameterized test is expensive. Do it once and
  # cache the result.
  indices = _CACHED_INDICES.get(n)
  if indices is None:
    rng = npr.RandomState(42)
    _CACHED_INDICES[n] = indices = rng.permutation(n)
  return [xs[i] for i in indices[:k]]

def cases_from_gens(*gens):
  sizes = [1, 3, 10]
  cases_per_size = int(FLAGS.num_generated_cases / len(sizes)) + 1
@@ -703,14 +689,20 @@ def named_cases_from_sampler(gen):
    yield case


# Random sampling for every parameterized test is expensive. Do it once and
# cache the result.
@functools.lru_cache(maxsize=None)
def _choice(n, m):
  rng = np.random.RandomState(42)
  return rng.choice(n, size=m, replace=False)

def sample_product_testcases(*args, **kw):
  """Non-decorator form of sample_product."""
  args = [list(arg) for arg in args]
  kw = [(k, list(v)) for k, v in kw.items()]
  n = prod(len(a) for a in args) * prod(len(v) for _, v in kw)
  rng = np.random.RandomState(42)
  testcases = []
  for i in rng.choice(n, size=min(n, FLAGS.num_generated_cases), replace=False):
  for i in _choice(n, min(n, FLAGS.num_generated_cases)):
    testcase = {}
    for a in args:
      testcase.update(a[i % len(a)])
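Why the cached helper: sample_product_testcases draws the same pseudo-random subset of the parameter product on every call, so memoizing the draw avoids re-running np.random.choice for each parameterized test. A simplified sketch of the idea behind the lru_cache'd _choice above (pick_indices is a made-up name, not the library function):

    import numpy as np

    _cache = {}
    def pick_indices(n, m):
      # Deterministic subsample of range(n): seed 42 keeps the selected test
      # cases stable across runs, and the cache makes repeat calls cheap.
      if (n, m) not in _cache:
        _cache[(n, m)] = np.random.RandomState(42).choice(n, size=m, replace=False)
      return _cache[(n, m)]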
@@ -21,7 +21,6 @@ import unittest

from absl import logging
from absl.testing import absltest
from absl.testing import parameterized

import jax
from jax import ad_checkpoint
@@ -200,12 +199,10 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):

self.assertIsNotNone(f_tf.get_concrete_function())

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_dtype={dtype.__name__}_function={with_function}",
dtype=dtype,
with_function=with_function)
for dtype in [np.int64, np.float64]
for with_function in [True, False]))
@jtu.sample_product(
dtype=[np.int64, np.float64],
with_function=[True, False],
)
def test_converts_64bit(self, dtype=np.int64, with_function=False):
if not config.jax_enable_x64:
self.skipTest("requires x64 mode")
@@ -251,10 +248,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))
self.ConvertAndCompare(f_jax, 0.7)

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
@jtu.sample_product(with_function=[False, True])
def test_gradients_disabled(self, with_function=False):
f_tf = jax2tf.convert(jnp.tan, with_gradient=False)
if with_function:
@@ -270,10 +264,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
y = f_tf(x)
_ = tape.gradient(y, x)

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
@jtu.sample_product(with_function=[False, True])
def test_gradients(self, with_function=True):
def f(x, y):
return x * x, x * y
@@ -291,10 +282,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(5., tape.gradient(v, x))
self.assertAllClose(4., tape.gradient(v, y))

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
@jtu.sample_product(with_function=[False, True])
def test_gradients_pytree(self, with_function=True):
def f(xy: Tuple[float, float]) -> Dict[str, float]:
x, y = xy
@@ -368,10 +356,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(grad_jax.b, grad_tf[1])

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
@jtu.sample_product(with_function=[False, True])
def test_gradients_with_ordered_dict_input(self, with_function=True):
def f(inputs):
out = 0.0
@@ -394,10 +379,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(np.array([1.]), tape.gradient(u, x).numpy())
self.assertAllClose(np.array([1., 1.]), tape.gradient(u, y).numpy())

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
@jtu.sample_product(with_function=[False, True])
def test_gradients_with_custom_jvp(self, with_function=True):
"""Check gradients, for a function with custom JVP."""
@jax.custom_jvp
@@ -428,10 +410,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(4. * 4., y)
self.assertAllClose(3. * 4., tape.gradient(y, x))

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
@jtu.sample_product(with_function=[False, True])
def test_gradients_with_custom_vjp(self, with_function=True):
"""Check gradients, for a function with custom VJP."""
@jax.custom_vjp
@@ -498,10 +477,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(jnp.zeros(np.shape(d_dx_jax), np.int32),
d_dx_tf.numpy())

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
@jtu.sample_product(with_function=[False, True])
def test_gradients_unused_argument_readme(self, with_function=False):
# x1 and x3 are not used. x3 has integer type.
def fn(x0, x1, x2, x3):
@@ -546,10 +522,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(g_jax2tf[2].numpy(), np.float32(2.))
self.assertAllClose(g_jax2tf[3].numpy(), np.int32(0))

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
@jtu.sample_product(with_function=[False, True])
def test_gradients_int_argument(self, with_function=False):
# https://github.com/google/jax/issues/6975
# Also issue #6975.
@@ -875,9 +848,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
jax2tf.convert(outer)(2.)

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_{transform}", transform=transform)
for transform in ["jit", "jvp", "grad", "vmap"]))
@jtu.sample_product(transform=["jit", "jvp", "grad", "vmap"])
def test_convert_under_transform_error(self, transform="vmap"):
def outer(y):
return jax2tf.convert(jnp.sin)(y)  # Inner convert takes tracer args
@@ -886,9 +857,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
ValueError, "convert must be used outside all JAX transformations"):
self.TransformConvertAndCompare(outer, np.ones((4,)), transform)

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_{transform}", transform=transform)
for transform in ["jit", "jvp", "grad", "vmap"]))
@jtu.sample_product(transform=["jit", "jvp", "grad", "vmap"])
def test_convert_under_transform_error_non_tracer(self, transform="vmap"):
def outer(y):
sin_1 = jax2tf.convert(jnp.sin)(1.)  # Inner convert takes non-tracer arg
@@ -997,10 +966,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(tf_fn(tf.constant(1.375, tf.bfloat16)).numpy(),
jnp.bfloat16(2.750))

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
@jtu.sample_product(with_function=[False, True])
def test_kwargs(self, with_function=True):
# Re: https://github.com/google/jax/issues/6791
def f_jax(*, x):
@@ -1012,10 +978,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
f_tf(x=np.zeros(3, dtype=np.float32)),  # Call with kwargs.
np.zeros((), dtype=np.float32))

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
@jtu.sample_product(with_function=[False, True])
def test_grad_kwargs(self, with_function=False):
# Re: https://github.com/google/jax/issues/6791
x = (np.zeros(3, dtype=np.float32),
@@ -294,29 +294,19 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
f_jax = jax.jit(lambda i: params[i])
self.ConvertAndCompare(f_jax, indices)

@parameterized.named_parameters(
jtu.cases_from_list(
dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
for f_jax in REDUCE))
@jtu.sample_product(f_jax=REDUCE)
def test_reduce_ops_with_numerical_input(self, f_jax):
values = np.array([1, 2, 3], dtype=np.float32)
self.ConvertAndCompare(f_jax, values)

@parameterized.named_parameters(
jtu.cases_from_list(
dict(testcase_name=f"_{op}", op=op) for op in (
"add", "max", "min", "multiply", "set"
)))
@jtu.sample_product(op=["add", "max", "min", "multiply", "set"])
def test_scatter_static(self, op):
values = np.ones((5, 6), dtype=np.float32)
update = np.float32(6.)
f_jax = jax.jit(lambda v, u: getattr(v.at[::2, 3:], op)(u))
self.ConvertAndCompare(f_jax, values, update)

@parameterized.named_parameters(
jtu.cases_from_list(
dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
for f_jax in REDUCE))
@jtu.sample_product(f_jax=REDUCE)
def test_reduce_ops_with_boolean_input(self, f_jax):
values = np.array([True, False, True], dtype=np.bool_)
self.ConvertAndCompare(f_jax, values)
@@ -834,10 +834,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
res_jax,
jax2tf.convert(f, polymorphic_shapes=["(b, h)", "h"])(x, y))

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
@jtu.sample_product(with_function=[False, True])
def test_grad_int(self, with_function=True):
# https://github.com/google/jax/issues/7093
# Also issue #6975.
@@ -14,7 +14,7 @@

import unittest

from absl.testing import absltest, parameterized
from absl.testing import absltest

import jax
from jax.config import config
@@ -67,16 +67,12 @@ class DLPackTest(jtu.JaxTestCase):
if jtu.device_under_test() == "tpu":
self.skipTest("DLPack not supported on TPU")

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_take_ownership={}_gpu={}".format(
jtu.format_shape_dtype_string(shape, dtype),
take_ownership, gpu),
"shape": shape, "dtype": dtype, "take_ownership": take_ownership,
"gpu": gpu}
for shape in all_shapes
for dtype in dlpack_dtypes
for take_ownership in [False, True]
for gpu in [False, True]))
@jtu.sample_product(
shape=all_shapes,
dtype=dlpack_dtypes,
take_ownership=[False, True],
gpu=[False, True],
)
@jtu.skip_on_devices("rocm") # TODO(sharadmv,phawkins): see GH issue #10973
def testJaxRoundTrip(self, shape, dtype, take_ownership, gpu):
rng = jtu.rand_default(self.rng())
@@ -95,12 +91,10 @@ class DLPackTest(jtu.JaxTestCase):
"DLPack tensor may be consumed at most once",
lambda: jax.dlpack.from_dlpack(dlpack))

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(
jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in all_shapes
for dtype in dlpack_dtypes))
@jtu.sample_product(
shape=all_shapes,
dtype=dlpack_dtypes,
)
@unittest.skipIf(not tf, "Test requires TensorFlow")
@jtu.skip_on_devices("rocm") # TODO(sharadmv,phawkins): see GH issue #10973
def testTensorFlowToJax(self, shape, dtype):
@@ -121,12 +115,10 @@ class DLPackTest(jtu.JaxTestCase):
y = jax.dlpack.from_dlpack(dlpack)
self.assertAllClose(np, y)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(
jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in all_shapes
for dtype in dlpack_dtypes))
@jtu.sample_product(
shape=all_shapes,
dtype=dlpack_dtypes,
)
@unittest.skipIf(not tf, "Test requires TensorFlow")
def testJaxToTensorFlow(self, shape, dtype):
if not config.x64_enabled and dtype in [jnp.int64, jnp.uint64,
@@ -145,12 +137,10 @@ class DLPackTest(jtu.JaxTestCase):
y = tf.experimental.dlpack.from_dlpack(dlpack)
self.assertAllClose(np, y.numpy())

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(
jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in all_shapes
for dtype in torch_dtypes))
@jtu.sample_product(
shape=all_shapes,
dtype=torch_dtypes,
)
@unittest.skipIf(not torch, "Test requires PyTorch")
def testTorchToJax(self, shape, dtype):
if not config.x64_enabled and dtype in [jnp.int64, jnp.float64]:
@@ -177,12 +167,10 @@ class DLPackTest(jtu.JaxTestCase):
xla_client._xla.dlpack_managed_tensor_to_buffer(
y, client)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(
jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in all_shapes
for dtype in torch_dtypes))
@jtu.sample_product(
shape=all_shapes,
dtype=torch_dtypes,
)
@unittest.skipIf(not torch, "Test requires PyTorch")
def testJaxToTorch(self, shape, dtype):
if not config.x64_enabled and dtype in [jnp.int64, jnp.float64]:
@@ -194,12 +182,10 @@ class DLPackTest(jtu.JaxTestCase):
y = torch.utils.dlpack.from_dlpack(dlpack)
self.assertAllClose(np, y.cpu().numpy())

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(
jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in all_shapes
for dtype in torch_dtypes))
@jtu.sample_product(
shape=all_shapes,
dtype=torch_dtypes,
)
@unittest.skipIf(numpy_version < (1, 22, 0), "Requires numpy 1.22 or newer")
@jtu.skip_on_devices("rocm") # TODO(sharadmv,phawkins): see GH issue #10973
def testNumpyToJax(self, shape, dtype):
@@ -208,12 +194,10 @@ class DLPackTest(jtu.JaxTestCase):
x_jax = jnp.from_dlpack(x_np)
self.assertAllClose(x_np, x_jax)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(
jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in all_shapes
for dtype in torch_dtypes))
@jtu.sample_product(
shape=all_shapes,
dtype=torch_dtypes,
)
@unittest.skipIf(numpy_version < (1, 22, 0), "Requires numpy 1.22 or newer")
@jtu.skip_on_devices("gpu")
def testJaxToNumpy(self, shape, dtype):
@@ -230,12 +214,10 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):
if jtu.device_under_test() != "gpu":
self.skipTest("__cuda_array_interface__ is only supported on GPU")

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(
jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in all_shapes
for dtype in dlpack_dtypes))
@jtu.sample_product(
shape=all_shapes,
dtype=dlpack_dtypes,
)
@unittest.skipIf(not cupy, "Test requires CuPy")
def testJaxToCuPy(self, shape, dtype):
if dtype == jnp.bfloat16:
@@ -24,7 +24,6 @@ import unittest
from unittest import skip, SkipTest

from absl.testing import absltest
from absl.testing import parameterized

import jax
from jax import ad_checkpoint
@@ -514,12 +513,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
self.assertEqual(
len(local_devices()), len(re.findall(r"112", testing_stream.output)))

@parameterized.named_parameters(
jtu.cases_from_list(
dict(
testcase_name=f"_with_jit_{with_jit}",
with_jit=with_jit)
for with_jit in [True, False]))
@jtu.sample_product(with_jit=[True, False])
def test_tap_pytree(self, with_jit=False):
def func(x, what=""):
"""Returns some pytrees depending on x"""
@@ -551,12 +545,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
hcb.barrier_wait()  # Wait for receivers to be done
self.assertEqual(3, tap_count)

@parameterized.named_parameters(
jtu.cases_from_list(
dict(
testcase_name=f"_concurrent_{concurrent}",
concurrent=concurrent)
for concurrent in [True, False]))
@jtu.sample_product(concurrent=[True, False])
def test_tap_multiple(self, concurrent=False):
"""Call id_tap multiple times, concurrently or in sequence. """
if concurrent and jtu.device_under_test() in ["cpu", "gpu"]:
@@ -628,12 +617,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
[t.start() for t in threads]
[t.join() for t in threads]

@parameterized.named_parameters(
jtu.cases_from_list(
dict(
testcase_name=f"_with_jit_{with_jit}",
with_jit=with_jit)
for with_jit in [True, False]))
@jtu.sample_product(with_jit=[True, False])
def test_tap_cond(self, with_jit=False):
"""A conditional"""

@@ -663,11 +647,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
where: end
4""", testing_stream.output)

@parameterized.named_parameters(
jtu.cases_from_list(
dict(testcase_name=f"_with_jit_{with_jit}",
with_jit=with_jit)
for with_jit in [True, False]))
@jtu.sample_product(with_jit=[True, False])
def test_tap_while_cond(self, with_jit=False):
def func(x):
x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
@@ -741,12 +721,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
where: 3
3""", testing_stream.output)

@parameterized.named_parameters(
jtu.cases_from_list(
dict(
testcase_name=f"_with_jit_{with_jit}",
with_jit=with_jit)
for with_jit in [True, False]))
@jtu.sample_product(with_jit=[True, False])
def test_tap_scan_cond(self, with_jit=True):
def func(x):
x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
@@ -796,15 +771,11 @@ class HostCallbackTapTest(jtu.JaxTestCase):
[1 2 3]""", testing_stream.output)
testing_stream.reset()

@parameterized.named_parameters(
jtu.cases_from_list(
dict(
testcase_name=f"_shape_{shape}_dtype_{np.dtype(dtype).name}_nr_args={nr_args}",
shape=shape,
dtype=dtype,
nr_args=nr_args) for nr_args in [1, 2]
for shape in [(), (2,), (2, 3), (2, 3, 4)]
for dtype in jtu.dtypes.all))
@jtu.sample_product(
nr_args=[1, 2],
shape=[(), (2,), (2, 3), (2, 3, 4)],
dtype=jtu.dtypes.all,
)
def test_tap_jit_dtypes(self, nr_args=2, dtype=jnp.int16, shape=(2,)):
if dtype in (jnp.complex64, jnp.complex128, jnp.bool_):
raise SkipTest(f"host_callback not implemented for {dtype}.")
@@ -1971,13 +1942,11 @@ class HostCallbackTapTest(jtu.JaxTestCase):
10"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)

@parameterized.named_parameters(
jtu.cases_from_list(
dict(testcase_name=f"_use_remat={use_remat}_{grad_func}_use_result={use_result}",
use_result=use_result, use_remat=use_remat, grad_func=grad_func)
for use_result in [True, False]
for grad_func in ["grad", "value_and_grad"]
for use_remat in ["old", "new", "none"]))
@jtu.sample_product(
use_result=[True, False],
grad_func=["grad", "value_and_grad"],
use_remat=["old", "new", "none"],
)
def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="new"):
if use_remat == "old": raise SkipTest()

@@ -2108,11 +2077,9 @@ class HostCallbackCallTest(jtu.JaxTestCase):
self.assertAllClose(2 * arg, fun(arg))
self.assertEqual(count[0], 1)

@parameterized.named_parameters(
jtu.cases_from_list(
dict(testcase_name=f"_{np.dtype(dtype).name}", dtype=dtype)
for dtype in jtu.dtypes.all
if dtype != np.bool_))
@jtu.sample_product(
dtype=[dtype for dtype in jtu.dtypes.all if dtype != np.bool_],
)
def test_call_types(self, dtype=np.float64):

def f_outside(x):
@@ -133,7 +133,7 @@ LAX_GRAD_OPS = [
grad_test_spec(lax.rsqrt, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_complex_dtypes),
grad_test_spec(lax.cbrt, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_float_dtypes, tol={np.float64: 3e-5}),
dtypes=grad_float_dtypes, tol={np.float64: 5e-3}),
grad_test_spec(lax.logistic, nargs=1, order=2,
rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes),
@@ -195,38 +195,43 @@ def check_grads_bilinear(f, args, order,

class LaxAutodiffTest(jtu.JaxTestCase):

@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(
rec.name, shapes, itertools.repeat(dtype)),
"op": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes, "dtype": dtype,
"order": rec.order, "tol": rec.tol}
for shape_group in compatible_shapes
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(op=rec.op, rng_factory=rec.rng_factory, order=rec.order, tol=rec.tol)],
shapes=[
shapes for shape_group in compatible_shapes
for shapes in itertools.combinations_with_replacement(shape_group, rec.nargs)
for dtype in rec.dtypes)
for rec in LAX_GRAD_OPS))
],
dtype=rec.dtypes,
)
for rec in LAX_GRAD_OPS
))
def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol):
rng = rng_factory(self.rng())
if jtu.device_under_test() == "tpu" and op is lax.pow:
raise SkipTest("pow grad imprecise on tpu")
if jtu.device_under_test() == "tpu":
if op is lax.pow:
raise SkipTest("pow grad imprecise on tpu")
if op is lax.cos:
order = 1  # 2nd-order gradient is imprecise on TPU.

tol = jtu.join_tolerance(1e-1, tol) if jtu.num_float_bits(dtype) == 32 else tol
args = tuple(rng(shape, dtype) for shape in shapes)
check_grads(op, args, order, ["fwd", "rev"], tol, tol)

@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": f"_{rec.op.__name__}_{special_value}",
"op": rec.op, "special_value": special_value, "tol": rec.tol}
for special_value in rec.values)
for rec in LAX_GRAD_SPECIAL_VALUE_TESTS))
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(op=rec.op, tol=rec.tol)],
special_value=rec.values,
)
for rec in LAX_GRAD_SPECIAL_VALUE_TESTS
))
def testOpGradSpecialValue(self, op, special_value, tol):
check_grads(op, (special_value,), 2, ["fwd", "rev"], rtol=tol, atol=tol)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_from_dtype={}_to_dtype={}".format(
jtu.dtype_str(from_dtype), jtu.dtype_str(to_dtype)),
"from_dtype": from_dtype, "to_dtype": to_dtype}
for from_dtype, to_dtype in itertools.product(inexact_dtypes, repeat=2)))
@jtu.sample_product(
from_dtype=inexact_dtypes,
to_dtype=inexact_dtypes,
)
def testConvertElementTypeGrad(self, from_dtype, to_dtype):
rng = jtu.rand_default(self.rng())
tol = max(jtu.tolerance(to_dtype, jtu.default_gradient_tolerance),
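Where the parameter space depends on data, as in testOpGrad above (each LAX_GRAD_OPS record supplies its own dtypes and shapes), the diff uses the non-decorator form together with parameterized.parameters. A hedged sketch of that pattern follows; RECORDS and the field names are placeholders, not the real test records:

    @parameterized.parameters(itertools.chain.from_iterable(
        jtu.sample_product_testcases(
            [dict(op=rec.op)],   # fields shared by every case drawn from this record
            dtype=rec.dtypes,    # axis whose values differ per record
        )
        for rec in RECORDS))
    def testExample(self, op, dtype):
      ...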
@@ -237,12 +242,10 @@ class LaxAutodiffTest(jtu.JaxTestCase):
convert_element_type)
check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(
jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in [(), (2, 3)]
for dtype in grad_float_dtypes))
@jtu.sample_product(
shape=[(), (2, 3)],
dtype=grad_float_dtypes,
)
def testClampGrad(self, shape, dtype):
rng = jtu.rand_default(self.rng())
operand = rng(shape, dtype)
@@ -253,15 +256,14 @@ class LaxAutodiffTest(jtu.JaxTestCase):
check_grads(lax.clamp, (low, operand, high), 2, ["fwd", "rev"], eps=1e-2)
check_grads(lax.clamp, (low, high, operand), 2, ["fwd", "rev"], eps=1e-2)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format(
dim, ",".join(str(d) for d in base_shape), np.dtype(dtype).name,
num_arrs),
"dim": dim, "base_shape": base_shape, "dtype": dtype, "num_arrs": num_arrs}
for num_arrs in [3]
for dtype in float_dtypes
@jtu.sample_product(
[dict(base_shape=base_shape, dim=dim)
for base_shape in [(4,), (3, 4), (2, 3, 4)]
for dim in range(len(base_shape))))
for dim in range(len(base_shape))
],
num_arrs=[3],
dtype=float_dtypes,
)
def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs):
rng = jtu.rand_default(self.rng())
shapes = [base_shape[:dim] + (size,) + base_shape[dim+1:]
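Two kinds of axes appear in the sample_product calls from here on, as in testConcatenateGrad above. A hedged sketch of the pattern (the method name and example values are illustrative only):

    @jtu.sample_product(
        # Positional list of dicts: these parameters are tied and vary together,
        # which keeps dim valid for its base_shape.
        [dict(base_shape=base_shape, dim=dim)
         for base_shape in [(4,), (3, 4)]
         for dim in range(len(base_shape))],
        # Keyword axes vary independently; the full product is subsampled.
        dtype=[np.float32],
    )
    def testExampleGrad(self, base_shape, dim, dtype):
      ...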
@@ -270,21 +272,17 @@ class LaxAutodiffTest(jtu.JaxTestCase):
concatenate = lambda *args: lax.concatenate(args, dim)
check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
.format(jtu.format_shape_dtype_string(lhs_shape, dtype),
jtu.format_shape_dtype_string(rhs_shape, dtype),
strides, padding),
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
"strides": strides, "padding": padding}
@jtu.sample_product(
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, strides=strides)
for lhs_shape, rhs_shape, all_strides in itertools.chain(
[((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)])
for b, i, j in itertools.product([2, 3], repeat=3)],
[((4, 2, 1), (3, 2, 1), [(1,)])])
for strides in all_strides
for dtype in float_dtypes
for padding in ["VALID", "SAME"]))
],
dtype=float_dtypes,
padding=["VALID", "SAME"],
)
def testConvGrad(self, lhs_shape, rhs_shape, dtype, strides, padding):
rng = jtu.rand_small(self.rng())
lhs = rng(lhs_shape, dtype)
@@ -294,18 +292,11 @@ class LaxAutodiffTest(jtu.JaxTestCase):
check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
atol=1e-2, rtol=1e-2)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
"rhs_dilation={}"
.format(jtu.format_shape_dtype_string(lhs_shape, dtype),
jtu.format_shape_dtype_string(rhs_shape, dtype),
strides, padding, lhs_dil, rhs_dil),
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
"strides": strides, "padding": padding, "lhs_dil": lhs_dil,
"rhs_dil": rhs_dil}
@jtu.sample_product(
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, strides=strides,
padding=padding, lhs_dil=lhs_dil, rhs_dil=rhs_dil)
for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in
itertools.chain(
itertools.chain(
[((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)],
[((0, 0), (0, 0)), ((-1, 0), (0, -1)), ((1, 0), (0, 1))],
[(1, 1), (2, 1)], [(1, 1)])
@@ -315,8 +306,10 @@ class LaxAutodiffTest(jtu.JaxTestCase):
for strides in all_strides
for rhs_dil in rhs_dils
for lhs_dil in lhs_dils
for dtype in float_dtypes
for padding in all_pads))
for padding in all_pads
],
dtype=float_dtypes,
)
def testConvWithGeneralPaddingGrad(self, lhs_shape, rhs_shape, dtype, strides,
padding, lhs_dil, rhs_dil):
rng = jtu.rand_small(self.rng())
@@ -328,19 +321,11 @@ class LaxAutodiffTest(jtu.JaxTestCase):
check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
atol=1e-2, rtol=1e-2)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
"rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}"
.format(jtu.format_shape_dtype_string(lhs_shape, dtype),
jtu.format_shape_dtype_string(rhs_shape, dtype),
strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums),
feature_group_count, batch_group_count),
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
"strides": strides, "padding": padding, "lhs_dil": lhs_dil,
"rhs_dil": rhs_dil, "dimension_numbers": dim_nums,
"perms": perms, "feature_group_count": feature_group_count,
"batch_group_count": batch_group_count}
@jtu.sample_product(
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, strides=strides,
padding=padding, lhs_dil=lhs_dil, rhs_dil=rhs_dil,
feature_group_count=feature_group_count,
batch_group_count=batch_group_count)
for batch_group_count, feature_group_count in ([(1, 1), (2, 1), (1, 2)])
for lhs_shapes, rhs_shape, all_strides, lhs_dils, rhs_dils in [
([(b * batch_group_count, i * feature_group_count, 6, 7),
@@ -354,13 +339,17 @@ class LaxAutodiffTest(jtu.JaxTestCase):
for strides in all_strides
for rhs_dil in rhs_dils
for lhs_dil in lhs_dils
for dtype in grad_inexact_dtypes
for padding in ([((0, 0), (0, 0)), ((1, 0), (0, 1))] +
([((0, -1), (0, 0))] if lhs_shape[2] != 0 else []))
([((0, -1), (0, 0))] if lhs_shape[2] != 0 else []))
],
[dict(dimension_numbers=dim_nums, perms=perms)
for dim_nums, perms in [
(("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
(("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
(("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))]))
(("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
(("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
(("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))]
],
dtype=grad_inexact_dtypes,
)
def testConvGeneralDilatedGrad(self, lhs_shape, rhs_shape, dtype, strides,
padding, lhs_dil, rhs_dil, dimension_numbers,
perms, feature_group_count, batch_group_count):
@@ -386,13 +375,11 @@ class LaxAutodiffTest(jtu.JaxTestCase):
check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
atol=tol, rtol=tol)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_lhs_shape={}_rhs_shape={}".format(
jtu.format_shape_dtype_string(lhs_shape, dtype),
jtu.format_shape_dtype_string(rhs_shape, dtype)),
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype}
for lhs_shape in [(2,), (3, 2)] for rhs_shape in [(2,), (2, 4)]
for dtype in float_dtypes))
@jtu.sample_product(
lhs_shape=[(2,), (3, 2)],
rhs_shape=[(2,), (2, 4)],
dtype=float_dtypes,
)
def testDotGrad(self, lhs_shape, rhs_shape, dtype):
rng = jtu.rand_default(self.rng())
tol = {np.float16: 1e-1, np.float32: 1e-4}
@@ -407,14 +394,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
s = str(jax.make_jaxpr(pullback)(gresult))
assert "Precision.HIGHEST" in s

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_lhs_shape={}_rhs_shape={}_dimension_numbers={}"
.format(jtu.format_shape_dtype_string(lhs_shape, dtype),
jtu.format_shape_dtype_string(rhs_shape, dtype),
dimension_numbers),
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
"dimension_numbers": dimension_numbers}
@jtu.sample_product(
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers)
for lhs_shape, rhs_shape, dimension_numbers in [
((3, 2), (2, 4), (([1], [0]), ([], []))),
((3, 5), (2, 5), (([1], [1]), ([], []))),
@@ -423,7 +405,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
((3, 5, 2), (2, 4, 5), (([2], [0]), ([1], [2]))),
((7, 3, 5, 2), (2, 2, 4, 5), (([3], [0]), ([2], [3]))),
]
for dtype in float_dtypes))
],
dtype=float_dtypes,
)
def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype,
dimension_numbers):
rng = jtu.rand_small(self.rng())
@@ -438,46 +422,36 @@ class LaxAutodiffTest(jtu.JaxTestCase):
s = str(jax.make_jaxpr(pullback)(gresult))
assert "Precision.HIGHEST" in s

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format(
shape, np.dtype(dtype).name, broadcast_sizes),
"shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes}
for shape in [(), (2, 3)]
for dtype in float_dtypes
for broadcast_sizes in [(), (2,), (1, 2)]))
@jtu.sample_product(
shape=[(), (2, 3)],
dtype=float_dtypes,
broadcast_sizes=[(), (2,), (1, 2)],
)
def testBroadcastGrad(self, shape, dtype, broadcast_sizes):
rng = jtu.rand_default(self.rng())
args = (rng(shape, dtype),)
broadcast = lambda x: lax.broadcast(x, broadcast_sizes)
check_grads(broadcast, args, 2, ["fwd", "rev"], eps=1.)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_outshape={}_bcdims={}".format(
jtu.format_shape_dtype_string(inshape, dtype),
outshape, broadcast_dimensions),
"inshape": inshape, "dtype": dtype, "outshape": outshape,
"dimensions": broadcast_dimensions}
@jtu.sample_product(
[dict(inshape=inshape, outshape=outshape, dimensions=broadcast_dimensions)
for inshape, outshape, broadcast_dimensions in [
([2], [2, 2], [0]),
([2], [2, 2], [1]),
([2], [2, 3], [0]),
([], [2, 3], []),
]
for dtype in float_dtypes))
],
dtype=float_dtypes,
)
def testBroadcastInDimGrad(self, inshape, dtype, outshape, dimensions):
rng = jtu.rand_default(self.rng())
operand = rng(inshape, dtype)
broadcast_in_dim = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
check_grads(broadcast_in_dim, (operand,), 2, ["fwd", "rev"], eps=1.)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_outshape={}_perm={}".format(
jtu.format_shape_dtype_string(arg_shape, dtype),
jtu.format_shape_dtype_string(out_shape, dtype),
permutation),
"arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
"permutation": permutation}
for dtype in float_dtypes
@jtu.sample_product(
[dict(arg_shape=arg_shape, out_shape=out_shape, permutation=permutation)
for arg_shape, out_shape, permutation in [
[(3, 4), (12,), None],
[(2, 1, 4), (8,), None],
@@ -488,23 +462,26 @@ class LaxAutodiffTest(jtu.JaxTestCase):
[(2, 1, 4), (8,), (2, 0, 1)],
[(2, 2, 4), (2, 8), (0, 2, 1)],
[(2, 2, 4), (2, 8), (2, 0, 1)],
]))
]
],
dtype=float_dtypes,
)
def testReshapeGrad(self, arg_shape, out_shape, permutation, dtype):
rng = jtu.rand_default(self.rng())
operand = rng(arg_shape, dtype)
reshape = lambda x: lax.reshape(x, out_shape, permutation)
check_grads(reshape, (operand,), 2, ["fwd", "rev"], eps=1.)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_pads={}"
.format(jtu.format_shape_dtype_string(shape, dtype), pads),
"shape": shape, "dtype": dtype, "pads": pads}
for dtype in float_dtypes
@jtu.sample_product(
[dict(shape=shape, pads=pads)
for shape, paddings in [
[(), [()]],
((2, 3), [[(1, 2, 1), (0, 1, 0)], [(-1, 0, 0), (-1, 0, 2)]]),
]
for pads in paddings))
for pads in paddings
],
dtype=float_dtypes,
)
def testPadGrad(self, shape, dtype, pads):
rng = jtu.rand_small(self.rng())
operand = rng(shape, dtype)
@@ -546,14 +523,13 @@ class LaxAutodiffTest(jtu.JaxTestCase):
# self.assertEqual(result, 0.0)
self.assertAllClose(result, np.nan)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_predshape={}_argshapes={}".format(
jtu.format_shape_dtype_string(pred_shape, np.bool_),
jtu.format_shape_dtype_string(arg_shape, dtype)),
"pred_shape": pred_shape, "arg_shape": arg_shape, "dtype": dtype}
@jtu.sample_product(
[dict(arg_shape=arg_shape, pred_shape=pred_shape)
for arg_shape in [(), (3,), (2, 3)]
for pred_shape in ([(), arg_shape] if arg_shape else [()])
for dtype in float_dtypes))
],
dtype=float_dtypes,
)
def testSelectGrad(self, pred_shape, arg_shape, dtype):
rng = jtu.rand_default(self.rng())
pred = rng(pred_shape, np.bool_)
@@ -562,13 +538,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
select = lambda on_true, on_false: lax.select(pred, on_true, on_false)
check_grads(select, (on_true, on_false), 2, ["fwd", "rev"], eps=1.)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}_start_indices={}_limit_indices={}_strides={}".format(
jtu.format_shape_dtype_string(shape, dtype),
start_indices, limit_indices, strides),
"shape": shape, "dtype": dtype, "starts": start_indices,
"limits": limit_indices, "strides": strides}
@jtu.sample_product(
[dict(shape=shape, starts=start_indices, limits=limit_indices,
strides=strides)
for shape, start_indices, limit_indices, strides in [
[(3,), (1,), (2,), None],
[(7,), (4,), (7,), None],
@@ -581,44 +553,43 @@ class LaxAutodiffTest(jtu.JaxTestCase):
[(5, 3), (1, 1), (5, 3), (2, 1)],
[(3, 3, 5), (0, 2, 0), (3, 2, 5), (1, 2, 1)]
]
for dtype in float_dtypes))
],
dtype=float_dtypes,
)
def testSliceGrad(self, shape, dtype, starts, limits, strides):
rng = jtu.rand_default(self.rng())
operand = rng(shape, dtype)
slice = lambda x: lax.slice(x, starts, limits, strides)
check_grads(slice, (operand,), 2, ["fwd", "rev"], eps=1.)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_start_indices={}_size_indices={}".format(
jtu.format_shape_dtype_string(shape, dtype),
start_indices, size_indices),
"shape": shape, "dtype": dtype, "start_indices": start_indices,
"size_indices": size_indices}
@jtu.sample_product(
[dict(shape=shape, start_indices=start_indices, size_indices=size_indices)
for shape, start_indices, size_indices in [
[(3,), (1,), (1,)],
[(5, 3), (1, 1), (3, 1)],
[(7, 5, 3), (4, 1, 0), (2, 0, 1)],
]
for dtype in float_dtypes))
],
dtype=float_dtypes,
)
def testDynamicSliceGrad(self, shape, dtype, start_indices, size_indices):
rng = jtu.rand_default(self.rng())
operand = rng(shape, dtype)
dynamic_slice = lambda x: lax.dynamic_slice(x, start_indices, size_indices)
check_grads(dynamic_slice, (operand,), 2, ["fwd", "rev"], eps=1.)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_start_indices={}_update_shape={}".format(
jtu.format_shape_dtype_string(shape, dtype),
start_indices, update_shape),
"shape": shape, "dtype": dtype, "start_indices": start_indices,
"update_shape": update_shape}
@jtu.sample_product(
[dict(shape=shape, start_indices=start_indices, update_shape=update_shape)
for shape, start_indices, update_shape in [
[(3,), (1,), (1,)],
[(5, 3), (1, 1), (3, 1)],
[(7, 5, 3), (4, 1, 0), (2, 0, 1)],
]
for dtype in float_dtypes))
def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices, update_shape):
],
dtype=float_dtypes,
)
def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices,
update_shape):
rng = jtu.rand_default(self.rng())
operand = rng(shape, dtype)
update = rng(update_shape, dtype)
@@ -664,28 +635,25 @@ class LaxAutodiffTest(jtu.JaxTestCase):
result2, _ = jax.value_and_grad(f, 0)(x, y)
self.assertAllClose(result1, result2)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_perm={}".format(
jtu.format_shape_dtype_string(shape, dtype), perm),
"shape": shape, "dtype": dtype, "perm": perm}
@jtu.sample_product(
[dict(shape=shape, perm=perm)
for shape, perm in [
[(3, 4), (1, 0)],
[(3, 4), (0, 1)],
[(3, 4, 5), (2, 1, 0)],
[(3, 4, 5), (1, 0, 2)],
]
for dtype in float_dtypes))
],
dtype=float_dtypes,
)
def testTransposeGrad(self, shape, dtype, perm):
rng = jtu.rand_default(self.rng())
operand = rng(shape, dtype)
transpose = lambda x: lax.transpose(x, perm)
check_grads(transpose, (operand,), 2, ["fwd", "rev"], eps=1.)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_inshape={}_reducedims={}"
.format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims),
"op": op, "init_val": init_val, "shape": shape, "dtype": dtype,
"dims": dims, "rng_factory": rng_factory}
@jtu.sample_product(
[dict(init_val=init_val, op=op, dtype=dtype, rng_factory=rng_factory)
for init_val, op, dtypes, rng_factory in [
(0, lax.add, float_dtypes + jtu.dtypes.complex, jtu.rand_default),
(-np.inf, lax.max, grad_inexact_dtypes, jtu.rand_unique_int),
@@ -693,6 +661,8 @@ class LaxAutodiffTest(jtu.JaxTestCase):
(1, lax.mul, grad_float_dtypes, partial(jtu.rand_default, scale=1)),
]
for dtype in dtypes
],
[dict(shape=shape, dims=dims)
for shape, dims in [
[(), ()],
[(3, 4, 5), ()],
@@ -702,7 +672,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
[(3, 4, 5), (0, 1, 2)],
[(3, 1), (1,)],
[(3, 0, 5), (1,)],
]))
]
],
)
def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory):
rng = rng_factory(self.rng())
if jtu.device_under_test() == "tpu" and op is lax.mul:
@@ -718,11 +690,8 @@ class LaxAutodiffTest(jtu.JaxTestCase):
if op not in (lax.max, lax.min) or all(d > 0 for d in shape):
check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_reducedims={}"
.format(jtu.format_shape_dtype_string(shape, dtype), dims),
"shape": shape, "dtype": dtype, "dims": dims}
for dtype in grad_float_dtypes
@jtu.sample_product(
[dict(shape=shape, dims=dims)
for shape, dims in [
[(3, 4, 5), ()],
[(3, 4, 5), (0,)],
@@ -731,7 +700,10 @@ class LaxAutodiffTest(jtu.JaxTestCase):
[(3, 4, 5), (0, 1, 2)],
[(3, 1), (1,)],
[(3, 0, 5), (1,)],
]))
]
],
dtype=grad_float_dtypes,
)
def testReducePairGrad(self, shape, dtype, dims):
rng = jtu.rand_default(self.rng(), scale=1)
tol = {np.float32: 1e-2, np.float64: 1e-4}
@@ -742,20 +714,16 @@ class LaxAutodiffTest(jtu.JaxTestCase):
reduce = lambda xs, ys: lax.reduce((xs, ys), init_vals, op, dims)
check_grads(reduce, operands, 2, ["fwd", "rev"], tol, tol)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}"
"_basedilation={}_windowdilation={}")
.format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims,
strides, padding, base_dilation, window_dilation),
"op": op, "init_val": init_val, "dtype": dtype, "shape": shape,
"dims": dims, "strides": strides, "padding": padding,
"base_dilation": base_dilation, "window_dilation": window_dilation,
"rng_factory": rng_factory}
@jtu.sample_product(
[dict(init_val=init_val, op=op, dtype=dtype, rng_factory=rng_factory,
shape=shape, dims=dims, strides=strides, padding=padding,
base_dilation=base_dilation, window_dilation=window_dilation)
for init_val, op, dtypes, rng_factory in [
(0, lax.add, grad_float_dtypes, jtu.rand_small),
(-np.inf, lax.max, grad_float_dtypes, jtu.rand_unique_int),
(np.inf, lax.min, grad_float_dtypes, jtu.rand_unique_int),
]
for dtype in dtypes
for shape, dims, strides, padding, base_dilation, window_dilation in (
itertools.chain(
itertools.product(
@@ -772,7 +740,8 @@ class LaxAutodiffTest(jtu.JaxTestCase):
["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]],
[(1, 1, 1, 1)] + ([(2, 1, 3, 2)]),
[(1, 1, 1, 1)] + ([(1, 2, 2, 1)] if op is lax.add else []))))
for dtype in dtypes))
],
)
@jtu.ignore_warning(category=UserWarning,
message="Using reduced precision for gradient.*")
def testReduceWindowGrad(
@@ -808,20 +777,20 @@ class LaxAutodiffTest(jtu.JaxTestCase):
check_grads(fun, (operand,), gradient_order, ["fwd", "rev"], tol, tol,
eps)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_shape={}_axis={}_reverse={}"
.format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis,
reverse),
"op": op, "shape": shape, "dtype": dtype,
"axis": axis, "reverse": reverse}
@jtu.sample_product(
[dict(op=op, dtype=dtype)
for op, types in [
(lax.cumsum, [np.float32, np.float64]),
(lax.cumprod, [np.float32, np.float64]),
]
for dtype in types
],
[dict(shape=shape, axis=axis)
for shape in [[10], [3, 4, 5]]
for axis in range(len(shape))
for reverse in [False, True]))
],
reverse=[False, True],
)
def testCumulativeReduceGrad(self, op, shape, dtype, axis, reverse):
rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
else jtu.rand_small)
@@ -831,33 +800,30 @@ class LaxAutodiffTest(jtu.JaxTestCase):


# TODO(b/205052657): enable more tests when supported
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_axis={}_isstable={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis, is_stable),
"shape": shape, "dtype": dtype, "axis": axis, "is_stable": is_stable}
for dtype in [np.float32]
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in [(5,), (5, 7)]
for axis in [len(shape) - 1]
for is_stable in [False, True]))
],
dtype=[np.float32],
is_stable=[False, True],
)
def testSortGrad(self, shape, dtype, axis, is_stable):
rng = jtu.rand_default(self.rng())
rng = jtu.rand_unique_int(self.rng())
operand = rng(shape, dtype)
sort = lambda x: lax.sort(x, dimension=axis, is_stable=is_stable)
check_grads(sort, (operand,), 2, ["fwd", "rev"], eps=1e-2)

# TODO(b/205052657): enable more tests when supported
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_keyshape={}_valshape={}_axis={}_isstable={}".format(
jtu.format_shape_dtype_string(shape, key_dtype),
jtu.format_shape_dtype_string(shape, val_dtype),
axis, is_stable),
"shape": shape, "key_dtype": key_dtype, "val_dtype": val_dtype,
"axis": axis, "is_stable": is_stable}
for key_dtype in [np.float32]
for val_dtype in [np.float32]
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in [(3,), (5, 3)]
for axis in [len(shape) - 1]
for is_stable in [False, True]))
],
key_dtype=[np.float32],
val_dtype=[np.float32],
is_stable=[False, True],
)
def testSortKeyValGrad(self, shape, key_dtype, val_dtype, axis, is_stable):
rng = jtu.rand_default(self.rng())
# This test relies on the property that wherever keys are tied, values are
@@ -873,30 +839,28 @@ class LaxAutodiffTest(jtu.JaxTestCase):
fun = lambda keys, values: lax.sort_key_val(keys, values, axis, is_stable)
check_grads(fun, (keys, values), 2, ["fwd", "rev"], 1e-2, 1e-2, 1e-2)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_k={}".format(
jtu.format_shape_dtype_string(shape, dtype), k),
"shape": shape, "dtype": dtype, "k": k}
for dtype in [np.float32,]
for shape in [(4,), (5, 5), (2, 1, 4)]
for k in [1, 3]))
@jtu.sample_product(
dtype=[np.float32,],
shape=[(4,), (5, 5), (2, 1, 4)],
k=[1, 3],
)
def testTopKGrad(self, shape, dtype, k):
flat_values = np.arange(prod(shape), dtype=dtype)
values = self.rng().permutation(flat_values).reshape(shape)
fun = lambda vs: lax.top_k(vs, k=k)[0]
check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_idxs={}_axes={}".format(
jtu.format_shape_dtype_string(shape, dtype), idxs, axes),
"shape": shape, "dtype": dtype, "idxs": idxs, "axes": axes}
for dtype in float_dtypes
@jtu.sample_product(
[dict(shape=shape, idxs=idxs, axes=axes)
for shape, idxs, axes in [
[(3, 4, 5), (np.array([0, 2, 1]),), (0,)],
[(3, 4, 5), (np.array([-1, -2]),), (0,)],
[(3, 4, 5), (np.array([0, 2]), np.array([1, 3])), (0, 1)],
[(3, 4, 5), (np.array([0, 2]), np.array([1, 3])), (0, 2)],
]))
]
],
dtype=float_dtypes,
)
@jax.numpy_rank_promotion('allow')  # Test explicitly exercises implicit rank promotion.
def testIndexTakeGrad(self, shape, dtype, idxs, axes):
rng = jtu.rand_default(self.rng())
@@ -904,16 +868,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
index_take = lambda src: lax.index_take(src, idxs, axes)
check_grads(index_take, (src,), 2, ["fwd", "rev"], eps=1.)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
f"_idxs={jtu.format_shape_dtype_string(idxs.shape, idxs.dtype)}"
f"_dnums={dnums}_slice_sizes={slice_sizes}_mode={mode}"
f"_iteration={iteration}",
"shape": shape, "dtype": dtype, "idxs_shape": idxs.shape,
"idxs_dtype": idxs.dtype, "dnums": dnums, "slice_sizes": slice_sizes,
"max_idx": max_idx, "mode": mode}
for dtype in grad_float_dtypes
@jtu.sample_product(
[dict(shape=shape, idxs_shape=idxs.shape, idxs_dtype=idxs.dtype,
dnums=dnums, slice_sizes=slice_sizes, max_idx=max_idx)
for shape, idxs, dnums, slice_sizes, max_idx in [
((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers(
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
@@ -925,10 +882,13 @@ class LaxAutodiffTest(jtu.JaxTestCase):
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
(1, 3), 3),
]
for mode in ["clip", "fill", "promise_in_bounds"]
for iteration in range(5)))
],
dtype=grad_float_dtypes,
mode=["clip", "fill", "promise_in_bounds"],
iteration=range(5),
)
def testGatherGrad(self, shape, dtype, idxs_shape, idxs_dtype, dnums,
slice_sizes, mode, max_idx):
slice_sizes, mode, max_idx, iteration):
rng = jtu.rand_default(self.rng())
if mode == "promise_in_bounds":
rng_idx = jtu.rand_int(self.rng(), high=max_idx)
@@ -945,16 +905,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
x = rng(shape, dtype)
check_grads(gather, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
f"_shape={jtu.format_shape_dtype_string(arg_shape, dtype)}"
f"_idxs={jtu.format_shape_dtype_string(idxs.shape, idxs.dtype)}"
f"_update={update_shape}_dnums={dnums}_mode={mode}"
f"_iteration={iteration}",
"arg_shape": arg_shape, "dtype": dtype, "idxs_shape": idxs.shape,
"idxs_dtype": idxs.dtype, "update_shape": update_shape, "dnums": dnums,
"max_idx": max_idx, "mode": mode}
for dtype in grad_float_dtypes
@jtu.sample_product(
[dict(arg_shape=arg_shape, idxs_shape=idxs.shape, idxs_dtype=idxs.dtype,
dnums=dnums, update_shape=update_shape, max_idx=max_idx)
for arg_shape, idxs, update_shape, dnums, max_idx in [
((5,), np.array([[0], [2]]), (2,),
lax.ScatterDimensionNumbers(update_window_dims=(),
@@ -969,10 +922,13 @@ class LaxAutodiffTest(jtu.JaxTestCase):
inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,)), 3),
]
for mode in ["clip", "fill", "promise_in_bounds"]
for iteration in range(5)))
],
dtype=grad_float_dtypes,
mode=["clip", "fill", "promise_in_bounds"],
iteration=range(5),
)
def testScatterAddGrad(self, arg_shape, dtype, idxs_shape, idxs_dtype,
update_shape, dnums, max_idx, mode):
update_shape, dnums, max_idx, mode, iteration):
rng = jtu.rand_default(self.rng())
if mode == "promise_in_bounds":
rng_idx = jtu.rand_int(self.rng(), high=max_idx)
@ -987,13 +943,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
y = rng(update_shape, dtype)
|
||||
check_grads(scatter_add, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
|
||||
jtu.format_shape_dtype_string(arg_shape, dtype),
|
||||
idxs, update_shape, dnums),
|
||||
"arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
|
||||
"update_shape": update_shape, "dnums": dnums, "rng_idx_factory": rng_idx_factory}
|
||||
for dtype in grad_float_dtypes
|
||||
@jtu.sample_product(
|
||||
[dict(arg_shape=arg_shape, idxs=idxs, dnums=dnums,
|
||||
update_shape=update_shape, max_idx=max_idx)
|
||||
for arg_shape, idxs, update_shape, dnums, max_idx in [
|
||||
((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(), inserted_window_dims=(0,),
|
||||
@ -1005,13 +957,15 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
update_window_dims=(1,), inserted_window_dims=(0,),
|
||||
scatter_dims_to_operand_dims=(0,)), 3),
|
||||
]
|
||||
# Scatters with conflicting indices are not deterministic on GPU, so we
|
||||
# use indices that do not collide.
|
||||
for rng_idx_factory in [partial(jtu.rand_unique_int, high=max_idx)]))
|
||||
],
|
||||
dtype=grad_float_dtypes,
|
||||
)
|
||||
def testScatterGrad(self, arg_shape, dtype, idxs, update_shape, dnums,
|
||||
rng_idx_factory):
|
||||
max_idx):
|
||||
# Scatters with conflicting indices are not deterministic on GPU, so we
|
||||
# use indices that do not collide.
|
||||
rng_idx = jtu.rand_unique_int(self.rng(), high=max_idx)
|
||||
rng = jtu.rand_default(self.rng())
|
||||
rng_idx = rng_idx_factory(self.rng())
|
||||
idxs = rng_idx(idxs.shape, idxs.dtype)
|
||||
scatter = lambda x, y: lax.scatter(x, idxs, y, dimension_numbers=dnums)
|
||||
x = rng(arg_shape, dtype)
|
||||
@ -1028,13 +982,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
check_grads(f, (rng((5, 5), np.float32),), 2, ["fwd", "rev"], 1e-2, 1e-2,
|
||||
1.)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
|
||||
jtu.format_shape_dtype_string(arg_shape, dtype),
|
||||
idxs, update_shape, dnums),
|
||||
"arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
|
||||
"update_shape": update_shape, "dnums": dnums}
|
||||
for dtype in grad_float_dtypes
|
||||
@jtu.sample_product(
|
||||
[dict(arg_shape=arg_shape, idxs=idxs, dnums=dnums,
|
||||
update_shape=update_shape)
|
||||
for arg_shape, idxs, update_shape, dnums in [
|
||||
((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(), inserted_window_dims=(0,),
|
||||
@ -1045,7 +995,10 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(1,), inserted_window_dims=(0,),
|
||||
scatter_dims_to_operand_dims=(0,))),
|
||||
]))
|
||||
]
|
||||
],
|
||||
dtype=grad_float_dtypes,
|
||||
)
|
||||
def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
rng_idx = jtu.rand_int(self.rng(), high=max(arg_shape))
|
||||
@ -1055,13 +1008,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
y = rng(update_shape, dtype)
|
||||
check_grads(scatter_max, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
|
||||
jtu.format_shape_dtype_string(arg_shape, dtype),
|
||||
idxs, update_shape, dnums),
|
||||
"arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
|
||||
"update_shape": update_shape, "dnums": dnums}
|
||||
for dtype in grad_float_dtypes
|
||||
@jtu.sample_product(
|
||||
[dict(arg_shape=arg_shape, idxs=idxs, dnums=dnums,
|
||||
update_shape=update_shape)
|
||||
for arg_shape, idxs, update_shape, dnums in [
|
||||
((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(), inserted_window_dims=(0,),
|
||||
@ -1072,7 +1021,10 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(1,), inserted_window_dims=(0,),
|
||||
scatter_dims_to_operand_dims=(0,))),
|
||||
]))
|
||||
]
|
||||
],
|
||||
dtype=grad_float_dtypes,
|
||||
)
|
||||
def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
rng_idx = jtu.rand_int(self.rng(), high=max(arg_shape))
|
||||
|
@ -137,7 +137,7 @@ if numpy_version >= (1, 22): # where keyword added in numpy 1.22
|
||||
op_record("nanmean", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [],
|
||||
inexact=True),
|
||||
op_record("nanvar", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [],
|
||||
inexact=True),
|
||||
inexact=True, tolerance={np.float16: 3e-3}),
|
||||
op_record("nanstd", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [],
|
||||
inexact=True),
|
||||
]
|
||||
@ -148,7 +148,7 @@ JAX_REDUCER_NO_DTYPE_RECORDS = [
|
||||
op_record("max", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []),
|
||||
op_record("min", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []),
|
||||
op_record("var", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
|
||||
inexact=True),
|
||||
inexact=True, tolerance={jnp.bfloat16: 2e-2}),
|
||||
op_record("std", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
|
||||
inexact=True),
|
||||
op_record("nanmax", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan, []),
|
||||
@ -179,23 +179,25 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
for a in out]
|
||||
return f
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": "{}_inshape={}_axis={}_dtype={}_keepdims={}".format(
|
||||
rec.test_name.capitalize(),
|
||||
jtu.format_shape_dtype_string(shape, dtype), axis,
|
||||
"None" if out_dtype is None else np.dtype(out_dtype).name, keepdims),
|
||||
"rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, "out_dtype": out_dtype,
|
||||
"np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
|
||||
"axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
|
||||
for shape in rec.shapes for dtype in rec.dtypes
|
||||
for out_dtype in [None] + rec.dtypes if out_dtype not in unsigned_dtypes
|
||||
@parameterized.parameters(itertools.chain.from_iterable(
|
||||
jtu.sample_product_testcases(
|
||||
[dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact)],
|
||||
[dict(shape=shape, axis=axis, dtype=dtype)
|
||||
for shape in rec.shapes
|
||||
for dtype in rec.dtypes
|
||||
for axis in list(range(-len(shape), len(shape))) + [None]
|
||||
for keepdims in [False, True]
|
||||
if jtu.is_valid_shape(shape, dtype))
|
||||
for rec in JAX_REDUCER_RECORDS))
|
||||
def testReducer(self, np_op, jnp_op, rng_factory, shape, dtype, out_dtype,
|
||||
if jtu.is_valid_shape(shape, dtype)
|
||||
],
|
||||
out_dtype=[out_dtype for out_dtype in [None] + rec.dtypes
|
||||
if out_dtype not in unsigned_dtypes],
|
||||
keepdims=[False, True],
|
||||
)
|
||||
for rec in JAX_REDUCER_RECORDS
|
||||
))
|
||||
def testReducer(self, name, rng_factory, shape, dtype, out_dtype,
|
||||
axis, keepdims, inexact):
|
||||
np_op = getattr(np, name)
|
||||
jnp_op = getattr(jnp, name)
|
||||
rng = rng_factory(self.rng())
|
||||
@jtu.ignore_warning(category=np.ComplexWarning)
|
||||
@jtu.ignore_warning(category=RuntimeWarning,
|
||||
@ -223,23 +225,26 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(jnp_fun, args_maker, atol=tol,
|
||||
rtol=tol)
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format(
|
||||
rec.test_name.capitalize(),
|
||||
jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
|
||||
"rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
|
||||
"np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
|
||||
"axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
|
||||
@parameterized.parameters(itertools.chain.from_iterable(
|
||||
jtu.sample_product_testcases(
|
||||
[dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact,
|
||||
tolerance=rec.tolerance)],
|
||||
[dict(shape=shape, axis=axis, dtype=dtype)
|
||||
for shape in rec.shapes for dtype in rec.dtypes
|
||||
for axis in list(range(-len(shape), len(shape))) + [None]
|
||||
for keepdims in [False, True]
|
||||
if jtu.is_valid_shape(shape, dtype))
|
||||
for rec in JAX_REDUCER_NO_DTYPE_RECORDS))
|
||||
def testReducerNoDtype(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
|
||||
keepdims, inexact):
|
||||
if jtu.is_valid_shape(shape, dtype)
|
||||
],
|
||||
keepdims=[False, True],
|
||||
)
|
||||
for rec in JAX_REDUCER_NO_DTYPE_RECORDS
|
||||
))
|
||||
def testReducerNoDtype(self, name, rng_factory, shape, dtype, axis,
|
||||
keepdims, inexact, tolerance):
|
||||
np_op = getattr(np, name)
|
||||
jnp_op = getattr(jnp, name)
|
||||
rng = rng_factory(self.rng())
|
||||
is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
|
||||
is_bf16_nan_test = (dtype == jnp.bfloat16 and
|
||||
rng_factory.__name__ == 'rand_some_nan')
|
||||
@jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="Degrees of freedom <= 0 for slice.*")
|
||||
@jtu.ignore_warning(category=RuntimeWarning,
|
||||
@ -255,25 +260,28 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
|
||||
jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
tol = {np.float16: 0.002}
|
||||
tol = jtu.join_tolerance({np.float16: 0.002},
|
||||
tolerance or jtu.default_tolerance())
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol, atol=tol)
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}".format(
|
||||
rec.test_name.capitalize(),
|
||||
jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial),
|
||||
"rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
|
||||
"np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
|
||||
"initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
|
||||
@parameterized.parameters(itertools.chain.from_iterable(
|
||||
jtu.sample_product_testcases(
|
||||
[dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact)],
|
||||
[dict(shape=shape, axis=axis, dtype=dtype)
|
||||
for shape in rec.shapes for dtype in rec.dtypes
|
||||
for axis in list(range(-len(shape), len(shape))) + [None]
|
||||
for initial in [0, 1] for keepdims in [False, True]
|
||||
if jtu.is_valid_shape(shape, dtype))
|
||||
for rec in JAX_REDUCER_INITIAL_RECORDS))
|
||||
def testReducerInitial(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
|
||||
if jtu.is_valid_shape(shape, dtype)
|
||||
],
|
||||
initial=[0, 1],
|
||||
keepdims=[False, True],
|
||||
)
|
||||
for rec in JAX_REDUCER_INITIAL_RECORDS
|
||||
))
|
||||
def testReducerInitial(self, name, rng_factory, shape, dtype, axis,
|
||||
keepdims, initial, inexact):
|
||||
np_op = getattr(np, name)
|
||||
jnp_op = getattr(jnp, name)
|
||||
rng = rng_factory(self.rng())
|
||||
is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
|
||||
@jtu.ignore_warning(category=RuntimeWarning,
|
||||
@ -295,25 +303,27 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}_promote_integers={}".format(
|
||||
rec.test_name.capitalize(),
|
||||
jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial, promote_integers),
|
||||
"rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
|
||||
"np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
|
||||
"initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact,
|
||||
"promote_integers": promote_integers}
|
||||
@parameterized.parameters(itertools.chain.from_iterable(
|
||||
jtu.sample_product_testcases(
|
||||
[dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact)],
|
||||
[dict(shape=shape, axis=axis, dtype=dtype)
|
||||
for shape in rec.shapes for dtype in rec.dtypes
|
||||
for axis in list(range(-len(shape), len(shape))) + [None]
|
||||
for initial in [0, 1] for keepdims in [False, True]
|
||||
for promote_integers in [True, False]
|
||||
if jtu.is_valid_shape(shape, dtype))
|
||||
for rec in JAX_REDUCER_PROMOTE_INT_RECORDS))
|
||||
def testReducerPromoteInt(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
|
||||
if jtu.is_valid_shape(shape, dtype)
|
||||
],
|
||||
initial=[0, 1],
|
||||
keepdims=[False, True],
|
||||
promote_integers=[False, True],
|
||||
)
|
||||
for rec in JAX_REDUCER_PROMOTE_INT_RECORDS
|
||||
))
|
||||
def testReducerPromoteInt(self, name, rng_factory, shape, dtype, axis,
|
||||
keepdims, initial, inexact, promote_integers):
|
||||
np_op = getattr(np, name)
|
||||
jnp_op = getattr(jnp, name)
|
||||
rng = rng_factory(self.rng())
|
||||
is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
|
||||
is_bf16_nan_test = (dtype == jnp.bfloat16 and
|
||||
rng_factory.__name__ == 'rand_some_nan')
|
||||
@jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="Degrees of freedom <= 0 for slice.*")
|
||||
@jtu.ignore_warning(category=np.ComplexWarning)
|
||||
@ -338,21 +348,22 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format(
|
||||
rec.test_name.capitalize(),
|
||||
jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
|
||||
"rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
|
||||
"np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
|
||||
"axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
|
||||
@parameterized.parameters(itertools.chain.from_iterable(
|
||||
jtu.sample_product_testcases(
|
||||
[dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact)],
|
||||
[dict(shape=shape, axis=axis)
|
||||
for shape in rec.shapes if np.prod(shape) == 0
|
||||
for dtype in rec.dtypes
|
||||
for keepdims in [False, True]
|
||||
for axis in range(-len(shape), len(shape)) if shape[axis] >= 1)
|
||||
for rec in JAX_REDUCER_INITIAL_RECORDS))
|
||||
def testReducerNoInitialZeroDims(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
|
||||
for axis in range(-len(shape), len(shape)) if shape[axis] >= 1
|
||||
],
|
||||
dtype=rec.dtypes,
|
||||
keepdims=[False, True],
|
||||
)
|
||||
for rec in JAX_REDUCER_INITIAL_RECORDS
|
||||
))
|
||||
def testReducerNoInitialZeroDims(self, name, rng_factory, shape, dtype, axis,
|
||||
keepdims, inexact):
|
||||
np_op = getattr(np, name)
|
||||
jnp_op = getattr(jnp, name)
|
||||
rng = rng_factory(self.rng())
|
||||
is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
|
||||
@jtu.ignore_warning(category=RuntimeWarning,
|
||||
@ -374,23 +385,24 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}_whereshape={}".format(
|
||||
rec.test_name.capitalize(),
|
||||
jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial,
|
||||
jtu.format_shape_dtype_string(whereshape, bool)),
|
||||
"rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
|
||||
"np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), "whereshape": whereshape,
|
||||
"initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
|
||||
@parameterized.parameters(itertools.chain.from_iterable(
|
||||
jtu.sample_product_testcases(
|
||||
[dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact)],
|
||||
[dict(shape=shape, axis=axis, dtype=dtype, whereshape=whereshape)
|
||||
for shape in rec.shapes for dtype in rec.dtypes
|
||||
for whereshape in _compatible_shapes(shape)
|
||||
for axis in list(range(-len(shape), len(shape))) + [None]
|
||||
for initial in [0, 1] for keepdims in [False, True]
|
||||
if jtu.is_valid_shape(shape, dtype))
|
||||
for rec in JAX_REDUCER_INITIAL_RECORDS))
|
||||
def testReducerWhere(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
|
||||
if jtu.is_valid_shape(shape, dtype)
|
||||
for whereshape in _compatible_shapes(shape)
|
||||
],
|
||||
initial=[0, 1],
|
||||
keepdims=[False, True],
|
||||
)
|
||||
for rec in JAX_REDUCER_INITIAL_RECORDS
|
||||
))
|
||||
def testReducerWhere(self, name, rng_factory, shape, dtype, axis,
|
||||
keepdims, initial, inexact, whereshape):
|
||||
np_op = getattr(np, name)
|
||||
jnp_op = getattr(jnp, name)
|
||||
if (shape in [()] + scalar_shapes and
|
||||
dtype in [jnp.int16, jnp.uint16] and
|
||||
jnp_op in [jnp.min, jnp.max]):
|
||||
@ -417,23 +429,23 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}_whereshape={}".format(
|
||||
rec.test_name.capitalize(),
|
||||
jtu.format_shape_dtype_string(shape, dtype), axis, keepdims,
|
||||
jtu.format_shape_dtype_string(whereshape, bool)),
|
||||
"rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
|
||||
"np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), "whereshape": whereshape,
|
||||
"axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
|
||||
for shape in rec.shapes for dtype in rec.dtypes
|
||||
for whereshape in _compatible_shapes(shape)
|
||||
for axis in list(range(-len(shape), len(shape))) + [None]
|
||||
for keepdims in [False, True]
|
||||
if jtu.is_valid_shape(shape, dtype))
|
||||
for rec in JAX_REDUCER_WHERE_NO_INITIAL_RECORDS))
|
||||
def testReducerWhereNoInitial(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
|
||||
keepdims, inexact, whereshape):
|
||||
@parameterized.parameters(itertools.chain.from_iterable(
|
||||
jtu.sample_product_testcases(
|
||||
[dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact,
|
||||
tol=rec.tolerance)],
|
||||
[dict(shape=shape, axis=axis, dtype=dtype, whereshape=whereshape)
|
||||
for shape in rec.shapes for dtype in rec.dtypes
|
||||
for whereshape in _compatible_shapes(shape)
|
||||
for axis in list(range(-len(shape), len(shape))) + [None]
|
||||
if jtu.is_valid_shape(shape, dtype)
|
||||
],
|
||||
keepdims=[False, True],
|
||||
) for rec in JAX_REDUCER_WHERE_NO_INITIAL_RECORDS
|
||||
))
|
||||
def testReducerWhereNoInitial(self, name, rng_factory, shape, dtype, axis,
|
||||
keepdims, inexact, whereshape, tol):
|
||||
np_op = getattr(np, name)
|
||||
jnp_op = getattr(jnp, name)
|
||||
rng = rng_factory(self.rng())
|
||||
is_bf16_nan_test = dtype == jnp.bfloat16
|
||||
# Do not pass where via args_maker as that is incompatible with _promote_like_jnp.
|
||||
@ -458,7 +470,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
if numpy_version >= (1, 20, 2) or np_op.__name__ in ("all", "any"):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, atol=tol, rtol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
def testReductionOfOutOfBoundsAxis(self): # Issue 888
|
||||
@ -510,19 +522,14 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=check_dtypes,
|
||||
rtol=tol, atol=tol)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
"_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}"
|
||||
.format(shape, dtype.__name__, out_dtype.__name__, axis, ddof, keepdims),
|
||||
"shape": shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis,
|
||||
"ddof": ddof, "keepdims": keepdims}
|
||||
for shape in [(5,), (10, 5)]
|
||||
for dtype in all_dtypes
|
||||
for out_dtype in inexact_dtypes
|
||||
for axis in [None, 0, -1]
|
||||
for ddof in [0, 1, 2]
|
||||
for keepdims in [False, True]))
|
||||
@jtu.sample_product(
|
||||
shape=[(5,), (10, 5)],
|
||||
dtype=all_dtypes,
|
||||
out_dtype=inexact_dtypes,
|
||||
axis=[None, 0, -1],
|
||||
ddof=[0, 1, 2],
|
||||
keepdims=[False, True],
|
||||
)
|
||||
def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
|
||||
@ -547,19 +554,14 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol,
|
||||
atol=tol)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
"_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}"
|
||||
.format(shape, dtype, out_dtype, axis, ddof, keepdims),
|
||||
"shape": shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis,
|
||||
"ddof": ddof, "keepdims": keepdims}
|
||||
for shape in [(5,), (10, 5)]
|
||||
for dtype in all_dtypes
|
||||
for out_dtype in inexact_dtypes
|
||||
for axis in [None, 0, -1]
|
||||
for ddof in [0, 1, 2]
|
||||
for keepdims in [False, True]))
|
||||
@jtu.sample_product(
|
||||
shape=[(5,), (10, 5)],
|
||||
dtype=all_dtypes,
|
||||
out_dtype=inexact_dtypes,
|
||||
axis=[None, 0, -1],
|
||||
ddof=[0, 1, 2],
|
||||
keepdims=[False, True],
|
||||
)
|
||||
def testNanVar(self, shape, dtype, out_dtype, axis, ddof, keepdims):
|
||||
rng = jtu.rand_some_nan(self.rng())
|
||||
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
|
||||
@ -594,22 +596,20 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
z = jax.grad(jnp.nanstd)(x)
|
||||
self.assertEqual(jnp.isnan(z).sum(), 0)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
"_shape={}_dtype={}_y_shape={}_y_dtype={}_rowvar={}_ddof={}_bias={}_fweights={}_aweights={}".format(
|
||||
shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, aweights),
|
||||
"shape": shape, "y_shape": y_shape, "dtype": dtype, "y_dtype": y_dtype,"rowvar": rowvar, "ddof": ddof,
|
||||
"bias": bias, "fweights": fweights, "aweights": aweights}
|
||||
for shape in [(5,), (10, 5), (5, 10)]
|
||||
for dtype in all_dtypes
|
||||
for y_dtype in [None, dtype]
|
||||
for rowvar in [True, False]
|
||||
for y_shape in _get_y_shapes(y_dtype, shape, rowvar)
|
||||
for bias in [True, False]
|
||||
for ddof in [None, 2, 3]
|
||||
for fweights in [True, False]
|
||||
for aweights in [True, False]))
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, dtype=dtype, y_dtype=y_dtype, rowvar=rowvar,
|
||||
y_shape=y_shape)
|
||||
for shape in [(5,), (10, 5), (5, 10)]
|
||||
for dtype in all_dtypes
|
||||
for y_dtype in [None, dtype]
|
||||
for rowvar in [True, False]
|
||||
for y_shape in _get_y_shapes(y_dtype, shape, rowvar)
|
||||
],
|
||||
bias=[True, False],
|
||||
ddof=[None, 2, 3],
|
||||
fweights=[True, False],
|
||||
aweights=[True, False],
|
||||
)
|
||||
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
|
||||
def testCov(self, shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, aweights):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
|
@ -15,7 +15,6 @@
|
||||
from functools import partial
|
||||
import unittest
|
||||
|
||||
from absl.testing import parameterized
|
||||
from absl.testing import absltest
|
||||
import numpy as np
|
||||
import scipy.sparse.linalg
|
||||
@ -93,15 +92,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
M = None
|
||||
return M
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
"_shape={}_preconditioner={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype),
|
||||
preconditioner),
|
||||
"shape": shape, "dtype": dtype, "preconditioner": preconditioner}
|
||||
for shape in [(4, 4), (7, 7)]
|
||||
for dtype in [np.float64, np.complex128]
|
||||
for preconditioner in [None, 'identity', 'exact', 'random']))
|
||||
@jtu.sample_product(
|
||||
shape=[(4, 4), (7, 7)],
|
||||
dtype=[np.float64, np.complex128],
|
||||
preconditioner=[None, 'identity', 'exact', 'random'],
|
||||
)
|
||||
def test_cg_against_scipy(self, shape, dtype, preconditioner):
|
||||
if not config.x64_enabled:
|
||||
raise unittest.SkipTest("requires x64 mode")
|
||||
@ -132,12 +127,10 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
args_maker,
|
||||
tol=1e-6)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
f"_shape={jtu.format_shape_dtype_string(shape, dtype)}",
|
||||
"shape": shape, "dtype": dtype}
|
||||
for shape in [(2, 2)]
|
||||
for dtype in float_types + complex_types))
|
||||
@jtu.sample_product(
|
||||
shape=[(2, 2)],
|
||||
dtype=float_types + complex_types,
|
||||
)
|
||||
def test_cg_as_solve(self, shape, dtype):
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -219,16 +212,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
self.assertTrue(dtypes.is_weakly_typed(x))
|
||||
|
||||
# BICGSTAB
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
"_shape={}_preconditioner={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype),
|
||||
preconditioner),
|
||||
"shape": shape, "dtype": dtype, "preconditioner": preconditioner}
|
||||
for shape in [(5, 5)]
|
||||
for dtype in [np.float64, np.complex128]
|
||||
for preconditioner in [None, 'identity', 'exact', 'random']
|
||||
))
|
||||
@jtu.sample_product(
|
||||
shape=[(5, 5)],
|
||||
dtype=[np.float64, np.complex128],
|
||||
preconditioner=[None, 'identity', 'exact', 'random'],
|
||||
)
|
||||
def test_bicgstab_against_scipy(
|
||||
self, shape, dtype, preconditioner):
|
||||
if not config.jax_enable_x64:
|
||||
@ -266,16 +254,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
args_maker,
|
||||
tol=1e-4)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
"_shape={}_preconditioner={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype),
|
||||
preconditioner),
|
||||
"shape": shape, "dtype": dtype, "preconditioner": preconditioner}
|
||||
for shape in [(2, 2), (7, 7)]
|
||||
for dtype in float_types + complex_types
|
||||
for preconditioner in [None, 'identity', 'exact']
|
||||
))
|
||||
@jtu.sample_product(
|
||||
shape=[(2, 2), (7, 7)],
|
||||
dtype=float_types + complex_types,
|
||||
preconditioner=[None, 'identity', 'exact'],
|
||||
)
|
||||
@jtu.skip_on_devices("gpu")
|
||||
def test_bicgstab_on_identity_system(self, shape, dtype, preconditioner):
|
||||
A = jnp.eye(shape[1], dtype=dtype)
|
||||
@ -290,17 +273,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
solution_tol = 1e-8 if using_x64 else 1e-4
|
||||
self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
"_shape={}_preconditioner={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype),
|
||||
preconditioner),
|
||||
"shape": shape, "dtype": dtype, "preconditioner": preconditioner
|
||||
}
|
||||
for shape in [(2, 2), (4, 4)]
|
||||
for dtype in float_types + complex_types
|
||||
for preconditioner in [None, 'identity', 'exact']
|
||||
))
|
||||
@jtu.sample_product(
|
||||
shape=[(2, 2), (4, 4)],
|
||||
dtype=float_types + complex_types,
|
||||
preconditioner=[None, 'identity', 'exact'],
|
||||
)
|
||||
@jtu.skip_on_devices("gpu")
|
||||
def test_bicgstab_on_random_system(self, shape, dtype, preconditioner):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -339,18 +316,12 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
self.assertAllClose(expected, actual)
|
||||
|
||||
# GMRES
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
"_shape={}_preconditioner={}_solve_method={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype),
|
||||
preconditioner,
|
||||
solve_method),
|
||||
"shape": shape, "dtype": dtype, "preconditioner": preconditioner,
|
||||
"solve_method": solve_method}
|
||||
for shape in [(3, 3)]
|
||||
for dtype in [np.float64, np.complex128]
|
||||
for preconditioner in [None, 'identity', 'exact', 'random']
|
||||
for solve_method in ['incremental', 'batched']))
|
||||
@jtu.sample_product(
|
||||
shape=[(3, 3)],
|
||||
dtype=[np.float64, np.complex128],
|
||||
preconditioner=[None, 'identity', 'exact', 'random'],
|
||||
solve_method=['incremental', 'batched'],
|
||||
)
|
||||
def test_gmres_against_scipy(
|
||||
self, shape, dtype, preconditioner, solve_method):
|
||||
if not config.x64_enabled:
|
||||
@ -380,27 +351,20 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
partial(scipy_gmres, M=M, restart=2, maxiter=1),
|
||||
partial(lax_gmres, M=M, restart=2, maxiter=1, solve_method=solve_method),
|
||||
args_maker,
|
||||
tol=1e-10)
|
||||
tol=1e-9)
|
||||
|
||||
self._CheckAgainstNumpy(
|
||||
np.linalg.solve,
|
||||
partial(lax_gmres, M=M, atol=1e-6, solve_method=solve_method),
|
||||
args_maker,
|
||||
tol=1e-10)
|
||||
tol=1e-8)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
"_shape={}_preconditioner={}_solve_method={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype),
|
||||
preconditioner,
|
||||
solve_method),
|
||||
"shape": shape, "dtype": dtype, "preconditioner": preconditioner,
|
||||
"solve_method": solve_method}
|
||||
for shape in [(2, 2), (7, 7)]
|
||||
for dtype in float_types + complex_types
|
||||
for preconditioner in [None, 'identity', 'exact']
|
||||
for solve_method in ['batched', 'incremental']
|
||||
))
|
||||
@jtu.sample_product(
|
||||
shape=[(2, 2), (7, 7)],
|
||||
dtype=float_types + complex_types,
|
||||
preconditioner=[None, 'identity', 'exact'],
|
||||
solve_method=['batched', 'incremental'],
|
||||
)
|
||||
@jtu.skip_on_devices("gpu")
|
||||
def test_gmres_on_identity_system(self, shape, dtype, preconditioner,
|
||||
solve_method):
|
||||
@ -419,19 +383,12 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
solution_tol = 1e-8 if using_x64 else 1e-4
|
||||
self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
"_shape={}_preconditioner={}_solve_method={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype),
|
||||
preconditioner,
|
||||
solve_method),
|
||||
"shape": shape, "dtype": dtype, "preconditioner": preconditioner,
|
||||
"solve_method": solve_method}
|
||||
for shape in [(2, 2), (4, 4)]
|
||||
for dtype in float_types + complex_types
|
||||
for preconditioner in [None, 'identity', 'exact']
|
||||
for solve_method in ['incremental', 'batched']
|
||||
))
|
||||
@jtu.sample_product(
|
||||
shape=[(2, 2), (4, 4)],
|
||||
dtype=float_types + complex_types,
|
||||
preconditioner=[None, 'identity', 'exact'],
|
||||
solve_method=['incremental', 'batched'],
|
||||
)
|
||||
@jtu.skip_on_devices("gpu")
|
||||
def test_gmres_on_random_system(self, shape, dtype, preconditioner,
|
||||
solve_method):
|
||||
@ -469,15 +426,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
actual, _ = jax.scipy.sparse.linalg.gmres(A, b)
|
||||
self.assertAllClose(expected, actual)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
"_shape={}_preconditioner={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype),
|
||||
preconditioner),
|
||||
"shape": shape, "dtype": dtype, "preconditioner": preconditioner}
|
||||
for shape in [(2, 2), (3, 3)]
|
||||
for dtype in float_types + complex_types
|
||||
for preconditioner in [None, 'identity']))
|
||||
@jtu.sample_product(
|
||||
shape=[(2, 2), (3, 3)],
|
||||
dtype=float_types + complex_types,
|
||||
preconditioner=[None, 'identity'],
|
||||
)
|
||||
def test_gmres_arnoldi_step(self, shape, dtype, preconditioner):
|
||||
"""
|
||||
The Arnoldi decomposition within GMRES is correct.
|
||||
|
@ -141,7 +141,8 @@ JAX_SPECIAL_FUNCTION_RECORDS = [
|
||||
# of inputs to some reasonable intervals
|
||||
op_record("zeta", 2, float_dtypes, jtu.rand_positive, False),
|
||||
# TODO: float64 produces aborts on gpu, potentially related to use of jnp.piecewise
|
||||
op_record("expi", 1, [np.float32], jtu.rand_default, True),
|
||||
op_record("expi", 1, [np.float32], partial(jtu.rand_not_small, offset=0.1),
|
||||
True),
|
||||
op_record("exp1", 1, [np.float32], jtu.rand_positive, True),
|
||||
op_record("expn", 2, (int_dtypes, [np.float32]), jtu.rand_positive, True, (0,)),
|
||||
]
|
||||
@ -154,23 +155,19 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
"_shapes={}_axis={}_keepdims={}_return_sign={}_use_b_{}".format(
|
||||
jtu.format_shape_dtype_string(shapes, dtype),
|
||||
axis, keepdims, return_sign, use_b),
|
||||
# TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU.
|
||||
"shapes": shapes, "dtype": dtype,
|
||||
"axis": axis, "keepdims": keepdims,
|
||||
"return_sign": return_sign, "use_b": use_b}
|
||||
for shape_group in compatible_shapes for dtype in float_dtypes + complex_dtypes + int_dtypes
|
||||
@jtu.sample_product(
|
||||
[dict(shapes=shapes, axis=axis, use_b=use_b)
|
||||
for shape_group in compatible_shapes
|
||||
for use_b in [False, True]
|
||||
for shapes in itertools.product(*(
|
||||
(shape_group, shape_group) if use_b else (shape_group,)))
|
||||
for axis in range(-max(len(shape) for shape in shapes),
|
||||
max(len(shape) for shape in shapes))
|
||||
for keepdims in [False, True]
|
||||
for return_sign in [False, True]))
|
||||
],
|
||||
dtype=float_dtypes + complex_dtypes + int_dtypes,
|
||||
keepdims=[False, True],
|
||||
return_sign=[False, True],
|
||||
)
|
||||
@jtu.ignore_warning(category=RuntimeWarning, message="invalid value encountered in .*")
|
||||
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
|
||||
def testLogSumExp(self, shapes, dtype, axis,
|
||||
@ -235,23 +232,23 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
result = lsp_special.logsumexp(1.0, b=1.0)
|
||||
self.assertEqual(result, 1.0)
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": jtu.format_test_name_suffix(
|
||||
rec.test_name, shapes, dtypes),
|
||||
"rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes,
|
||||
"test_autodiff": rec.test_autodiff,
|
||||
"nondiff_argnums": rec.nondiff_argnums,
|
||||
"scipy_op": getattr(osp_special, rec.name),
|
||||
"lax_op": getattr(lsp_special, rec.name)}
|
||||
for shapes in itertools.combinations_with_replacement(all_shapes, rec.nargs)
|
||||
for dtypes in (itertools.combinations_with_replacement(rec.dtypes, rec.nargs)
|
||||
if isinstance(rec.dtypes, list) else itertools.product(*rec.dtypes)))
|
||||
for rec in JAX_SPECIAL_FUNCTION_RECORDS))
|
||||
@parameterized.parameters(itertools.chain.from_iterable(
|
||||
jtu.sample_product_testcases(
|
||||
[dict(op=rec.name, rng_factory=rec.rng_factory,
|
||||
test_autodiff=rec.test_autodiff,
|
||||
nondiff_argnums=rec.nondiff_argnums)],
|
||||
shapes=itertools.combinations_with_replacement(all_shapes, rec.nargs),
|
||||
dtypes=(itertools.combinations_with_replacement(rec.dtypes, rec.nargs)
|
||||
if isinstance(rec.dtypes, list) else itertools.product(*rec.dtypes)),
|
||||
)
|
||||
for rec in JAX_SPECIAL_FUNCTION_RECORDS
|
||||
))
|
||||
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
|
||||
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
|
||||
def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes, dtypes,
|
||||
def testScipySpecialFun(self, op, rng_factory, shapes, dtypes,
|
||||
test_autodiff, nondiff_argnums):
|
||||
scipy_op = getattr(osp_special, op)
|
||||
lax_op = getattr(lsp_special, op)
|
||||
rng = rng_factory(self.rng())
|
||||
args_maker = self._GetArgsMaker(rng, shapes, dtypes)
|
||||
args = args_maker()
|
||||
@ -272,13 +269,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
atol=jtu.if_device_under_test("tpu", .1, 1e-3),
|
||||
rtol=.1, eps=1e-3)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_inshape={}_d={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), d),
|
||||
"shape": shape, "dtype": dtype, "d": d}
|
||||
for shape in all_shapes
|
||||
for dtype in float_dtypes
|
||||
for d in [1, 2, 5]))
|
||||
@jtu.sample_product(
|
||||
shape=all_shapes,
|
||||
dtype=float_dtypes,
|
||||
d=[1, 2, 5],
|
||||
)
|
||||
@jax.numpy_rank_promotion('raise')
|
||||
def testMultigammaln(self, shape, dtype, d):
|
||||
def scipy_fun(a):
|
||||
@ -322,13 +317,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.)
|
||||
self.assertAllClose(jax.grad(partial_xlog1py)(-1.), 0., check_dtypes=False)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_lmax={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), l_max),
|
||||
"l_max": l_max, "shape": shape, "dtype": dtype}
|
||||
for l_max in [1, 2, 3, 6]
|
||||
for shape in [(5,), (10,)]
|
||||
for dtype in float_dtypes))
|
||||
@jtu.sample_product(
|
||||
l_max=[1, 2, 3, 6],
|
||||
shape=[(5,), (10,)],
|
||||
dtype=float_dtypes,
|
||||
)
|
||||
def testLpmn(self, l_max, shape, dtype):
|
||||
rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
@ -344,13 +337,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
atol=3e-3, check_dtypes=False)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=1E-5, atol=3e-3)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_lmax={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), l_max),
|
||||
"l_max": l_max, "shape": shape, "dtype": dtype}
|
||||
for l_max in [3, 4, 6, 32]
|
||||
for shape in [(2,), (3,), (4,), (64,)]
|
||||
for dtype in float_dtypes))
|
||||
@jtu.sample_product(
|
||||
l_max=[3, 4, 6, 32],
|
||||
shape=[(2,), (3,), (4,), (64,)],
|
||||
dtype=float_dtypes,
|
||||
)
|
||||
def testNormalizedLpmnValues(self, l_max, shape, dtype):
|
||||
rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
@ -432,11 +423,12 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
|
||||
self.assertAllClose(actual, expected, rtol=1e-8, atol=6e-8)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{'testcase_name': f'_maxdegree={l_max}_inputsize={num_z}_dtype={dtype.__name__}',
|
||||
'l_max': l_max, 'num_z': num_z, 'dtype': dtype}
|
||||
@jtu.sample_product(
|
||||
[dict(l_max=l_max, num_z=num_z)
|
||||
for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8])
|
||||
for dtype in jtu.dtypes.all_integer))
|
||||
],
|
||||
dtype=jtu.dtypes.all_integer,
|
||||
)
|
||||
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
|
||||
def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype):
|
||||
"""Tests against JIT compatibility and Numpy."""
|
||||
@ -475,32 +467,18 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
|
||||
self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{'testcase_name':
|
||||
'_shape={}'
|
||||
'_n_zero_sv={}_degeneracy={}_geometric_spectrum={}'
|
||||
'_max_sv={}_method={}_side={}'
|
||||
'_nonzero_condition_number={}_seed={}'.format(
|
||||
jtu.format_shape_dtype_string(
|
||||
shape, jnp.dtype(dtype).name).replace(" ", ""),
|
||||
n_zero_sv, degeneracy, geometric_spectrum, max_sv,
|
||||
method, side, nonzero_condition_number, seed
|
||||
),
|
||||
'n_zero_sv': n_zero_sv, 'degeneracy': degeneracy,
|
||||
'geometric_spectrum': geometric_spectrum,
|
||||
'max_sv': max_sv, 'shape': shape, 'method': method,
|
||||
'side': side, 'nonzero_condition_number': nonzero_condition_number,
|
||||
'dtype': dtype, 'seed': seed}
|
||||
for n_zero_sv in n_zero_svs
|
||||
for degeneracy in degeneracies
|
||||
for geometric_spectrum in geometric_spectra
|
||||
for max_sv in max_svs
|
||||
for shape in polar_shapes
|
||||
for method in methods
|
||||
for side in sides
|
||||
for nonzero_condition_number in nonzero_condition_numbers
|
||||
for dtype in jtu.dtypes.inexact
|
||||
for seed in seeds))
|
||||
@jtu.sample_product(
|
||||
n_zero_sv=n_zero_svs,
|
||||
degeneracy=degeneracies,
|
||||
geometric_spectrum=geometric_spectra,
|
||||
max_sv=max_svs,
|
||||
shape=polar_shapes,
|
||||
method=methods,
|
||||
side=sides,
|
||||
nonzero_condition_number=nonzero_condition_numbers,
|
||||
dtype=jtu.dtypes.inexact,
|
||||
seed=seeds,
|
||||
)
|
||||
def testPolar(
|
||||
self, n_zero_sv, degeneracy, geometric_spectrum, max_sv, shape, method,
|
||||
side, nonzero_condition_number, dtype, seed):
|
||||
@ -552,15 +530,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
self.assertAllClose(
|
||||
matrix, recon, atol=tol * jnp.linalg.norm(matrix))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{'testcase_name':
|
||||
'_linear_size={}_dtype={}_termination_size={}'.format(
|
||||
linear_size, jnp.dtype(dtype).name, termination_size),
|
||||
'linear_size': linear_size, 'dtype': dtype,
|
||||
'termination_size': termination_size}
|
||||
for linear_size in linear_sizes
|
||||
for dtype in jtu.dtypes.floating + jtu.dtypes.complex
|
||||
for termination_size in [1, 19]))
|
||||
@jtu.sample_product(
|
||||
linear_size=linear_sizes,
|
||||
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
|
||||
termination_size=[1, 19],
|
||||
)
|
||||
def test_spectral_dac_eigh(self, linear_size, dtype, termination_size):
|
||||
if jtu.device_under_test() != "tpu" and termination_size != 1:
|
||||
raise unittest.SkipTest(
|
||||
@ -584,13 +558,12 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
HV, vV, atol=atol * (80 if jnp.issubdtype(dtype, jnp.complexfloating)
|
||||
else 30))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_{jtu.format_shape_dtype_string((n_obs, n_codes, *n_feats), dtype)}",
|
||||
"n_obs": n_obs, "n_codes": n_codes, "n_feats": n_feats, "dtype": dtype}
|
||||
for n_obs in [1, 3, 5]
|
||||
for n_codes in [1, 2, 4]
|
||||
for n_feats in [()] + [(i,) for i in range(1, 3)]
|
||||
for dtype in float_dtypes + int_dtypes)) # scipy doesn't support complex
|
||||
@jtu.sample_product(
|
||||
n_obs=[1, 3, 5],
|
||||
n_codes=[1, 2, 4],
|
||||
n_feats=[()] + [(i,) for i in range(1, 3)],
|
||||
dtype=float_dtypes + int_dtypes, # scipy doesn't support complex
|
||||
)
|
||||
def test_vq(self, n_obs, n_codes, n_feats, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng((n_obs, *n_feats), dtype), rng((n_codes, *n_feats), dtype)]
|
||||
|
@ -16,7 +16,6 @@
|
||||
from functools import partial
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
import numpy.random as npr
|
||||
@ -35,12 +34,7 @@ npr.seed(0)
|
||||
class MultiBackendTest(jtu.JaxTestCase):
|
||||
"""Tests jit targeting to different backends."""
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_backend={backend}",
|
||||
"backend": backend,
|
||||
}
|
||||
for backend in ['cpu', 'gpu', 'tpu', None]
|
||||
))
|
||||
@jtu.sample_product(backend=['cpu', 'gpu', 'tpu', None])
|
||||
def testMultiBackend(self, backend):
|
||||
if backend not in ('cpu', jtu.device_under_test(), None):
|
||||
raise SkipTest("Backend is not CPU or the device under test")
|
||||
@ -56,10 +50,9 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
correct_platform = backend if backend else jtu.device_under_test()
|
||||
self.assertEqual(z.device().platform, correct_platform)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_ordering={ordering}",
|
||||
"ordering": ordering,}
|
||||
for ordering in [('cpu', None), ('gpu', None), ('tpu', None), (None, None)]))
|
||||
@jtu.sample_product(
|
||||
ordering=[('cpu', None), ('gpu', None), ('tpu', None), (None, None)]
|
||||
)
|
||||
def testMultiBackendNestedJit(self, ordering):
|
||||
outer, inner = ordering
|
||||
if outer not in ('cpu', jtu.device_under_test(), None):
|
||||
@ -81,14 +74,11 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
correct_platform = outer if outer else jtu.device_under_test()
|
||||
self.assertEqual(z.device().platform, correct_platform)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_ordering={ordering}",
|
||||
"ordering": ordering,}
|
||||
for ordering in [
|
||||
('cpu', 'gpu'), ('gpu', 'cpu'),
|
||||
('cpu', 'tpu'), ('tpu', 'cpu'),
|
||||
(None, 'cpu'), (None, 'gpu'), (None, 'tpu'),
|
||||
]))
|
||||
@jtu.sample_product(
|
||||
ordering=[('cpu', 'gpu'), ('gpu', 'cpu'), ('cpu', 'tpu'), ('tpu', 'cpu'),
|
||||
(None, 'cpu'), (None, 'gpu'), (None, 'tpu'),
|
||||
],
|
||||
)
|
||||
def testMultiBackendNestedJitConflict(self, ordering):
|
||||
outer, inner = ordering
|
||||
if outer not in ('cpu', jtu.device_under_test(), None):
|
||||
@ -111,11 +101,7 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
y = npr.uniform(size=(10, 10))
|
||||
self.assertRaises(ValueError, lambda: fun(x, y))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_backend={backend}",
|
||||
"backend": backend,}
|
||||
for backend in ['cpu', 'gpu', 'tpu']
|
||||
))
|
||||
@jtu.sample_product(backend=['cpu', 'gpu', 'tpu'])
|
||||
def testGpuMultiBackendOpByOpReturn(self, backend):
|
||||
if backend not in ('cpu', jtu.device_under_test()):
|
||||
raise SkipTest("Backend is not CPU or the device under test")
|
||||
|
@ -262,9 +262,7 @@ class PrngTest(jtu.JaxTestCase):
|
||||
expected64 = np.array([676898860, 3164047411, 4010691890], dtype=np.uint32)
|
||||
self.assertArraysEqual(bits64, expected64)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_" + name, "prng_name": name}
|
||||
for name, _ in PRNG_IMPLS))
|
||||
@jtu.sample_product(prng_name=[name for name, _ in PRNG_IMPLS])
|
||||
def testRngRandomBitsShapeDtype(self, prng_name):
|
||||
# Like testRngRandomBits, but only meant to exercise random_bits
|
||||
# on every PRNG implementation. Instead of values, only checks
|
||||
@ -313,9 +311,7 @@ class PrngTest(jtu.JaxTestCase):
|
||||
assert np.all(rand_bits_32 == rand_bits_32[0])
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": case._testname(), "case": case}
|
||||
for case in _RANDOM_VALUES_CASES))
|
||||
@jtu.sample_product(case=_RANDOM_VALUES_CASES)
|
||||
@jtu.skip_on_devices("tpu") # TPU precision causes issues.
|
||||
def testRandomDistributionValues(self, case):
|
||||
"""
|
||||
@ -366,8 +362,7 @@ class PrngTest(jtu.JaxTestCase):
|
||||
_prng_key_as_array(random.fold_in(k, 4)),
|
||||
np.array([2285895361, 433833334], dtype='uint32'))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "seed={seed}_type={type}_jit={jit}".format(**dct), **dct} for dct in [
|
||||
@jtu.sample_product([
|
||||
{"seed": 0, "type": int, "jit": True, "key": [0, 0]},
|
||||
{"seed": 0, "type": int, "jit": False, "key": [0, 0]},
|
||||
{"seed": 1, "type": np.int32, "jit": True, "key": [0, 1]},
|
||||
@ -390,8 +385,7 @@ class PrngTest(jtu.JaxTestCase):
|
||||
{"seed": np.iinfo(np.int32).min - 100, "type": int, "jit": False, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
|
||||
{"seed": np.iinfo(np.int32).min - 101, "type": np.int64, "jit": True, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
|
||||
{"seed": np.iinfo(np.int32).min - 101, "type": np.int64, "jit": False, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
|
||||
]
|
||||
))
|
||||
])
|
||||
def test_prng_seeds_and_keys(self, seed, type, jit, key):
|
||||
if (jit and type is int and not config.x64_enabled and
|
||||
(seed < np.iinfo('int32').min or seed > np.iinfo('int32').max)):
|
||||
@ -525,9 +519,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
def seed_prng(self, seed):
|
||||
return random.threefry2x32_key(seed)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype}
|
||||
for dtype in jtu.dtypes.floating))
|
||||
@jtu.sample_product(dtype=jtu.dtypes.floating)
|
||||
def testNumpyAndXLAAgreeOnFloatEndianness(self, dtype):
|
||||
bits_dtype = np.uint32 if jnp.finfo(dtype).bits == 32 else np.uint64
|
||||
numpy_bits = np.array(1., dtype).view(bits_dtype)
|
||||
@ -535,9 +527,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
lambda: lax.bitcast_convert_type(np.array(1., dtype), bits_dtype))()
|
||||
self.assertEqual(numpy_bits, xla_bits)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype}
|
||||
for dtype in float_dtypes))
|
||||
@jtu.sample_product(dtype=float_dtypes)
|
||||
def testRngUniform(self, dtype):
|
||||
key = self.seed_prng(0)
|
||||
rand = lambda key: random.uniform(key, (10000,), dtype)
|
||||
@ -550,9 +540,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
self._CheckCollisions(samples, jnp.finfo(dtype).nmant)
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype}
|
||||
for dtype in int_dtypes + uint_dtypes))
|
||||
@jtu.sample_product(dtype=int_dtypes + uint_dtypes)
|
||||
def testRngRandint(self, dtype):
|
||||
lo = 5
|
||||
hi = 10
|
||||
@ -568,9 +556,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
self.assertTrue(np.all(lo <= samples))
|
||||
self.assertTrue(np.all(samples < hi))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype}
|
||||
for dtype in float_dtypes))
|
||||
@jtu.sample_product(dtype=float_dtypes)
|
||||
def testNormal(self, dtype):
|
||||
key = self.seed_prng(0)
|
||||
rand = lambda key: random.normal(key, (10000,), dtype)
|
||||
@ -589,9 +575,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
res_bfloat16 = random.normal(self.seed_prng(0), dtype=jnp.bfloat16)
|
||||
self.assertAllClose(res_bfloat16, res_bfloat16_str)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"dtype={np.dtype(dtype).name}", "dtype": dtype}
|
||||
for dtype in complex_dtypes))
|
||||
@jtu.sample_product(dtype=complex_dtypes)
|
||||
def testNormalComplex(self, dtype):
|
||||
key = self.seed_prng(0)
|
||||
rand = lambda key: random.normal(key, (10000,), dtype)
|
||||
@ -605,9 +589,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
self._CheckKolmogorovSmirnovCDF(jnp.imag(samples), scipy.stats.norm(scale=1/np.sqrt(2)).cdf)
|
||||
self.assertEqual(dtype, samples.dtype)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype}
|
||||
for dtype in float_dtypes))
|
||||
@jtu.sample_product(dtype=float_dtypes)
|
||||
def testTruncatedNormal(self, dtype):
|
||||
key = self.seed_prng(0)
|
||||
rand = lambda key: random.truncated_normal(key, -0.3, 0.3, (10000,), dtype)
|
||||
@ -623,9 +605,7 @@ class LaxRandomTest(jtu.JaxTestCase):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.truncnorm(-0.3, 0.3).cdf)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype}
for dtype in jtu.dtypes.floating + jtu.dtypes.integer))
@jtu.sample_product(dtype=jtu.dtypes.floating + jtu.dtypes.integer)
def testShuffle(self, dtype):
key = self.seed_prng(0)
x = np.arange(100).astype(dtype)
@ -641,23 +621,21 @@ class LaxRandomTest(jtu.JaxTestCase):
self.assertFalse(np.all(perm1 == x)) # seems unlikely!
self.assertAllClose(np.sort(perm1), x, check_dtypes=False)

@parameterized.named_parameters(jtu.cases_from_list(
dict(
testcase_name=
f"_{np.dtype(dtype).name}_input_range_or_shape={input_range_or_shape}"
f"_shape={shape}_replace={replace}_weighted={weighted}_axis={axis}",
dtype=dtype, input_range_or_shape=input_range_or_shape,
shape=shape, replace=replace, weighted=weighted, axis=axis)
for dtype in jtu.dtypes.floating + jtu.dtypes.integer
for shape in [(), (5,), (4, 5)]
for replace in [True, False]
for weighted in [True, False]
for input_range_or_shape in [100, (10, 10), (10, 5, 2), 1, (1, 5)]
for is_range in [type(input_range_or_shape) is int]
for ndim in [1 if is_range else len(input_range_or_shape)]
for axis in range(-ndim, ndim or 1)
for ninputs in [input_range_or_shape if is_range else input_range_or_shape[axis]]
if replace or np.prod(shape) <= ninputs))
@jtu.sample_product(
[dict(shape=shape, replace=replace, axis=axis,
input_range_or_shape=input_range_or_shape)
for shape in [(), (5,), (4, 5)]
for replace in [True, False]
for input_range_or_shape in [100, (10, 10), (10, 5, 2), 1, (1, 5)]
for is_range in [type(input_range_or_shape) is int]
for ndim in [1 if is_range else len(input_range_or_shape)]
for axis in range(-ndim, ndim or 1)
for ninputs in [input_range_or_shape if is_range else input_range_or_shape[axis]]
if replace or np.prod(shape) <= ninputs
],
dtype=jtu.dtypes.floating + jtu.dtypes.integer,
weighted=[True, False],
)
def testChoice(self, dtype, input_range_or_shape, shape, replace, weighted, axis):
# This is the function API that we test against (note that self.rng().choice differs)
np_choice = np.random.default_rng(0).choice
@ -690,17 +668,16 @@ class LaxRandomTest(jtu.JaxTestCase):
self.assertArraysEqual(sample, jax.jit(rand, static_argnames=
'x' if is_range else None)(key, x))

@parameterized.named_parameters(jtu.cases_from_list(
dict(
testcase_name=f"_dtype={dtype}_range_or_shape={range_or_shape}"
f"_axis={axis}_independent={independent}",
dtype=dtype, range_or_shape=range_or_shape, axis=axis, independent=independent)
for dtype in jtu.dtypes.floating + jtu.dtypes.integer
for range_or_shape in [0, 1, 100, (0,), (1,), (100,),
(10, 10), (10, 5, 2), (0, 5), (1, 5)]
for ndim in [1 if type(range_or_shape) is int else len(range_or_shape)]
for axis in range(-ndim, ndim or 1)
for independent in [True, False]))
@jtu.sample_product(
[dict(range_or_shape=range_or_shape, axis=axis)
for range_or_shape in [0, 1, 100, (0,), (1,), (100,),
(10, 10), (10, 5, 2), (0, 5), (1, 5)]
for ndim in [1 if type(range_or_shape) is int else len(range_or_shape)]
for axis in range(-ndim, ndim or 1)
],
dtype=jtu.dtypes.floating + jtu.dtypes.integer,
independent=[True, False],
)
def testPermutation(self, dtype, range_or_shape, axis, independent):
key = self.seed_prng(0)
is_range = type(range_or_shape) is int
@ -739,11 +716,10 @@ class LaxRandomTest(jtu.JaxTestCase):
with self.assertRaises(core.ConcretizationTypeError):
jax.jit(random.permutation)(key, 10)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_p={p}_dtype={np.dtype(dtype).name}",
"p": p, "dtype": dtype}
for p in [0.1, 0.5, 0.9]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
p=[0.1, 0.5, 0.9],
dtype=jtu.dtypes.floating,
)
def testBernoulli(self, p, dtype):
key = self.seed_prng(0)
p = np.array(p, dtype=dtype)
@ -756,17 +732,18 @@ class LaxRandomTest(jtu.JaxTestCase):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckChiSquared(samples, scipy.stats.bernoulli(p).pmf)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_p={p}_{np.dtype(dtype).name}_{sample_shape}",
"p": p, "axis": axis, "dtype": dtype, 'sample_shape': sample_shape}
for (p, axis) in [
@jtu.sample_product(
[dict(p=p, axis=axis)
for (p, axis) in [
([.25] * 4, -1),
([.1, .2, .3, .4], -1),
([[.5, .5], [.1, .9]], 1),
([[.5, .1], [.5, .9]], 0),
]
for sample_shape in [(10000,), (5000, 2)]
for dtype in jtu.dtypes.floating))
]
],
sample_shape=[(10000,), (5000, 2)],
dtype=jtu.dtypes.floating,
)
def testCategorical(self, p, axis, dtype, sample_shape):
key = self.seed_prng(0)
p = np.array(p, dtype=dtype)
@ -800,12 +777,11 @@ class LaxRandomTest(jtu.JaxTestCase):
x = random.bernoulli(key, np.array([0.2, 0.3]), shape=(3, 2))
assert x.shape == (3, 2)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_a={a}_b={b}_dtype={np.dtype(dtype).name}",
"a": a, "b": b, "dtype": dtype}
for a in [0.2, 5.]
for b in [0.2, 5.]
for dtype in [np.float64])) # NOTE: KS test fails with float32
@jtu.sample_product(
a=[0.2, 5.],
b=[0.2, 5.],
dtype=[np.float64], # NOTE: KS test fails with float32
)
def testBeta(self, a, b, dtype):
if not config.x64_enabled:
raise SkipTest("skip test except on X64")
@ -832,9 +808,7 @@ class LaxRandomTest(jtu.JaxTestCase):
ones = samples[samples >= 0.5]
self.assertAllClose(ones, jnp.ones_like(ones))

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype}
for dtype in float_dtypes))
@jtu.sample_product(dtype=float_dtypes)
def testCauchy(self, dtype):
key = self.seed_prng(0)
rand = lambda key: random.cauchy(key, (10000,), dtype)
@ -846,13 +820,10 @@ class LaxRandomTest(jtu.JaxTestCase):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.cauchy().cdf)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_alpha={alpha}_dtype={np.dtype(dtype).name}",
"alpha": alpha, "dtype": dtype}
for alpha in [
np.array([0.2, 1., 5.]),
]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
alpha=[np.array([0.2, 1., 5.]),],
dtype=jtu.dtypes.floating,
)
@jtu.skip_on_devices("tpu") # TODO(mattjj): slow compilation times
def testDirichlet(self, alpha, dtype):
key = self.seed_prng(0)
@ -883,9 +854,7 @@ class LaxRandomTest(jtu.JaxTestCase):
self.assertAllClose(samples.max(1), jnp.ones(samples.shape[0]),
check_dtypes=False, rtol=1E-5)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype}
for dtype in float_dtypes))
@jtu.sample_product(dtype=float_dtypes)
def testExponential(self, dtype):
key = self.seed_prng(0)
rand = lambda key: random.exponential(key, (10000,), dtype)
@ -897,11 +866,10 @@ class LaxRandomTest(jtu.JaxTestCase):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.expon().cdf)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_a={a}_dtype={np.dtype(dtype).name}",
"a": a, "dtype": dtype}
for a in [0.1, 1., 10.]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
a=[0.1, 1., 10.],
dtype=jtu.dtypes.floating,
)
def testGammaVsLogGamma(self, a, dtype):
key = self.seed_prng(0)
rand_gamma = lambda key, a: random.gamma(key, a, (10000,), dtype)
@ -911,11 +879,10 @@ class LaxRandomTest(jtu.JaxTestCase):
self.assertAllClose(rand_gamma(key, a), jnp.exp(rand_loggamma(key, a)))
self.assertAllClose(rand_gamma(key, a), jnp.exp(crand_loggamma(key, a)))

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_a={a}_dtype={np.dtype(dtype).name}",
"a": a, "dtype": dtype}
for a in [0.1, 1., 10.]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
a=[0.1, 1., 10.],
dtype=jtu.dtypes.floating,
)
def testGamma(self, a, dtype):
key = self.seed_prng(0)
rand = lambda key, a: random.gamma(key, a, (10000,), dtype)
@ -932,11 +899,10 @@ class LaxRandomTest(jtu.JaxTestCase):
x = random.gamma(key, np.array([0.2, 0.3]), shape=(3, 2))
assert x.shape == (3, 2)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_a={alpha}_logspace={log_space}",
"alpha": alpha, "log_space": log_space}
for log_space in [True, False]
for alpha in [1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]))
@jtu.sample_product(
log_space=[True, False],
alpha=[1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4],
)
def testGammaGrad(self, log_space, alpha):
rng = self.seed_prng(0)
alphas = np.full((100,), alpha)
@ -969,11 +935,10 @@ class LaxRandomTest(jtu.JaxTestCase):
# Should not crash with a type error.
jax.vjp(f, a, b)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_lam={lam}_dtype={np.dtype(dtype).name}",
"lam": lam, "dtype": np.dtype(dtype)}
for lam in [0.5, 3, 9, 11, 50, 500]
for dtype in [np.int16, np.int32, np.int64]))
@jtu.sample_product(
lam=[0.5, 3, 9, 11, 50, 500],
dtype=[np.int16, np.int32, np.int64],
)
def testPoisson(self, lam, dtype):
key = self.seed_prng(0)
rand = lambda key, lam: random.poisson(key, lam, (10000,), dtype)
@ -1019,9 +984,7 @@ class LaxRandomTest(jtu.JaxTestCase):
samples = random.poisson(key, lam, shape=(3,))
self.assertArraysEqual(samples, jnp.array([-1, 0, -1]))

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype}
for dtype in jtu.dtypes.floating))
@jtu.sample_product(dtype=jtu.dtypes.floating)
def testGumbel(self, dtype):
key = self.seed_prng(0)
rand = lambda key: random.gumbel(key, (10000,), dtype)
@ -1033,9 +996,7 @@ class LaxRandomTest(jtu.JaxTestCase):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gumbel_r().cdf)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype}
for dtype in float_dtypes))
@jtu.sample_product(dtype=float_dtypes)
def testLaplace(self, dtype):
key = self.seed_prng(0)
rand = lambda key: random.laplace(key, (10000,), dtype)
@ -1047,9 +1008,7 @@ class LaxRandomTest(jtu.JaxTestCase):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.laplace().cdf)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype}
for dtype in float_dtypes))
@jtu.sample_product(dtype=float_dtypes)
def testLogistic(self, dtype):
key = self.seed_prng(0)
rand = lambda key: random.logistic(key, (10000,), dtype)
@ -1061,15 +1020,11 @@ class LaxRandomTest(jtu.JaxTestCase):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.logistic().cdf)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_n={}_shape={}"\
.format(n, jtu.format_shape_dtype_string(shape, dtype)),
"n": n,
"shape": shape,
"dtype": dtype}
for n in range(1, 5)
for shape in [(), (5,), (10, 5)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
@jtu.sample_product(
n=range(1, 5),
shape=[(), (5,), (10, 5)],
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
)
def testOrthogonal(self, n, shape, dtype):
key = self.seed_prng(0)
q = random.orthogonal(key, n, shape, dtype)
@ -1083,15 +1038,11 @@ class LaxRandomTest(jtu.JaxTestCase):
atol=tol, rtol=tol,
)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_p={}_shape={}"\
.format(p, jtu.format_shape_dtype_string(shape, dtype)),
"p": p,
"shape": shape,
"dtype": dtype}
for p in [.5, 1., 1.5, 2., 2.5]
for shape in [(), (5,), (10, 5)]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
p=[.5, 1., 1.5, 2., 2.5],
shape=[(), (5,), (10, 5)],
dtype=jtu.dtypes.floating,
)
def testGeneralizedNormal(self, p, shape, dtype):
key = self.seed_prng(0)
rand = lambda key, p: random.generalized_normal(key, p, shape, dtype)
@ -1103,17 +1054,12 @@ class LaxRandomTest(jtu.JaxTestCase):
self.assertEqual(samples.dtype, dtype)
self._CheckKolmogorovSmirnovCDF(samples.ravel(), scipy.stats.gennorm(p).cdf)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_d={}_p={}_shape={}"\
.format(d, p, jtu.format_shape_dtype_string(shape, dtype)),
"d": d,
"p": p,
"shape": shape,
"dtype": dtype}
for d in range(1, 5)
for p in [.5, 1., 1.5, 2., 2.5]
for shape in [(), (5,), (10, 5)]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
d=range(1, 5),
p=[.5, 1., 1.5, 2., 2.5],
shape=[(), (5,), (10, 5)],
dtype=jtu.dtypes.floating,
)
def testBall(self, d, p, shape, dtype):
key = self.seed_prng(0)
rand = lambda key, p: random.ball(key, d, p, shape, dtype)
@ -1127,11 +1073,10 @@ class LaxRandomTest(jtu.JaxTestCase):
norms = (jnp.abs(samples) ** p).sum(-1) ** (d / p)
self._CheckKolmogorovSmirnovCDF(norms.ravel(), scipy.stats.uniform().cdf)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_b={b}_dtype={np.dtype(dtype).name}",
"b": b, "dtype": dtype}
for b in [0.1, 1., 10.]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
b=[0.1, 1., 10.],
dtype=jtu.dtypes.floating,
)
def testPareto(self, b, dtype):
key = self.seed_prng(0)
rand = lambda key, b: random.pareto(key, b, (10000,), dtype)
@ -1149,11 +1094,10 @@ class LaxRandomTest(jtu.JaxTestCase):
x = random.pareto(key, np.array([0.2, 0.3]), shape=(3, 2))
assert x.shape == (3, 2)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_df={df}_dtype={np.dtype(dtype).name}",
"df": df, "dtype": dtype}
for df in [0.1, 1., 10.]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
df=[0.1, 1., 10.],
dtype=jtu.dtypes.floating,
)
@jtu.skip_on_devices("cpu", "tpu") # TODO(phawkins): slow compilation times
def testT(self, df, dtype):
key = self.seed_prng(0)
@ -1166,13 +1110,11 @@ class LaxRandomTest(jtu.JaxTestCase):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.t(df).cdf)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dim={}_dtype={}_method={}".format(
dim, np.dtype(dtype), method),
"dim": dim, "dtype": dtype, "method": method}
for dim in [1, 3, 5]
for dtype in float_dtypes
for method in ['svd', 'eigh', 'cholesky']))
@jtu.sample_product(
dim=[1, 3, 5],
dtype=float_dtypes,
method=['svd', 'eigh', 'cholesky'],
)
def testMultivariateNormal(self, dim, dtype, method):
r = self.rng()
mean = r.randn(dim)
@ -1198,18 +1140,13 @@ class LaxRandomTest(jtu.JaxTestCase):
# eigenvectors follow a standard normal distribution.
self._CheckKolmogorovSmirnovCDF(whitened.ravel(), scipy.stats.norm().cdf)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dim={}_mean_batch_size={}_cov_batch_size={}_shape={}_method={}"\
.format(dim, mean_batch_size, cov_batch_size, shape, method),
"dim": dim,
"mean_batch_size": mean_batch_size,
"cov_batch_size": cov_batch_size,
"shape": shape, "method": method}
for dim in [1, 2, 4]
for mean_batch_size in [(), (3,), (2, 3)]
for cov_batch_size in [(), (3,), (2, 3)]
for shape in [(), (1,), (5,)]
for method in ['cholesky', 'svd', 'eigh']))
@jtu.sample_product(
dim=[1, 2, 4],
mean_batch_size=[(), (3,), (2, 3)],
cov_batch_size=[(), (3,), (2, 3)],
shape=[(), (1,), (5,)],
method=['cholesky', 'svd', 'eigh'],
)
def testMultivariateNormalShapes(self, dim, mean_batch_size, cov_batch_size,
shape, method):
r = self.rng()
@ -1426,10 +1363,10 @@ class LaxRandomTest(jtu.JaxTestCase):
with jax.enable_checks(False): # check_jaxpr will materialize array
jax.eval_shape(f, 0) # doesn't error

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_seed={seed}_type={type_}", "seed": seed, "type_": type_}
for type_ in ["int", "np.array", "jnp.array"]
for seed in [-1, 0, 1, (1 << 32) - 1, (1 << 63) - 1, np.uint64((1 << 64) - 1)]))
@jtu.sample_product(
type_=["int", "np.array", "jnp.array"],
seed=[-1, 0, 1, (1 << 32) - 1, (1 << 63) - 1, np.uint64((1 << 64) - 1)],
)
def test_prng_jit_invariance(self, seed, type_):
if type_ == "int" and seed == (1 << 64) - 1:
self.skipTest("Expected failure: Python int too large.")
@ -1453,9 +1390,7 @@ class LaxRandomTest(jtu.JaxTestCase):
jax.jit(random.split)(key)
self.assertLessEqual(count[0], 1) # 1 for the argument device_put

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={dtype}", "dtype": dtype}
for dtype in int_dtypes + uint_dtypes))
@jtu.sample_product(dtype=int_dtypes + uint_dtypes)
def test_randint_bounds(self, dtype):
min = np.iinfo(dtype).min
max = np.iinfo(dtype).max
@ -1541,10 +1476,7 @@ class KeyArrayTest(jtu.JaxTestCase):
self.assertIsInstance(keys, random.KeyArray)
self.assertEqual(keys.shape, (3,))

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_internal" if use_internal else "",
"use_internal": use_internal}
for use_internal in [False, True]))
@jtu.sample_product(use_internal=[False, True])
def test_random_unwrap(self, use_internal):
unwrap = prng_internal.random_unwrap if use_internal else random.key_data
def f(k): return unwrap(k)

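# Illustrative sketch (not part of the diff): the plain keyword form of
# jtu.sample_product that the hunks above migrate to. Each keyword names a test
# parameter and maps to the list of values to draw from; the decorator takes
# the Cartesian product of all the lists and parameterizes the test over a
# random subset of that product (bounded by --num_generated_cases). The class,
# test name, and import path below are hypothetical stand-ins.
import numpy as np
from jax._src import test_util as jtu  # assumed import path

class ExampleTest(jtu.JaxTestCase):

  @jtu.sample_product(
      dtype=[np.float32, np.int32],   # one axis of the product
      shape=[(), (3,), (2, 5)],       # another axis; 2 * 3 = 6 combinations
  )
  def testZerosShape(self, dtype, shape):
    # Each sampled (dtype, shape) pair becomes its own named test case.
    self.assertEqual(np.zeros(shape, dtype).shape, shape)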
@ -16,7 +16,7 @@
from functools import partial
import unittest

from absl.testing import absltest, parameterized
from absl.testing import absltest

import numpy as np
import scipy.signal as osp_signal
@ -70,22 +70,19 @@ def _complex_dtype(dtype):
class LaxBackedScipySignalTests(jtu.JaxTestCase):
"""Tests for LAX-backed scipy.stats implementations"""

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_xshape={}_yshape={}_mode={}".format(
op,
jtu.format_shape_dtype_string(xshape, dtype),
jtu.format_shape_dtype_string(yshape, dtype),
mode),
"xshape": xshape, "yshape": yshape, "dtype": dtype, "mode": mode,
"jsp_op": getattr(jsp_signal, op),
"osp_op": getattr(osp_signal, op)}
for mode in ['full', 'same', 'valid']
for op in ['convolve', 'correlate']
for dtype in default_dtypes
for shapeset in [onedim_shapes, twodim_shapes, threedim_shapes]
for xshape in shapeset
for yshape in shapeset))
def testConvolutions(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
@jtu.sample_product(
[dict(xshape=xshape, yshape=yshape)
for shapeset in [onedim_shapes, twodim_shapes, threedim_shapes]
for xshape in shapeset
for yshape in shapeset
],
mode=['full', 'same', 'valid'],
op=['convolve', 'correlate'],
dtype=default_dtypes,
)
def testConvolutions(self, xshape, yshape, dtype, mode, op):
jsp_op = getattr(jsp_signal, op)
osp_op = getattr(osp_signal, op)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
osp_fun = partial(osp_op, mode=mode)
@ -94,21 +91,16 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False, tol=tol)
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "op={}_xshape={}_yshape={}_mode={}".format(
op,
jtu.format_shape_dtype_string(xshape, dtype),
jtu.format_shape_dtype_string(yshape, dtype),
mode),
"xshape": xshape, "yshape": yshape, "dtype": dtype, "mode": mode,
"jsp_op": getattr(jsp_signal, op),
"osp_op": getattr(osp_signal, op)}
for mode in ['full', 'same', 'valid']
for op in ['convolve2d', 'correlate2d']
for dtype in default_dtypes
for xshape in twodim_shapes
for yshape in twodim_shapes))
def testConvolutions2D(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
@jtu.sample_product(
mode=['full', 'same', 'valid'],
op=['convolve2d', 'correlate2d'],
dtype=default_dtypes,
xshape=twodim_shapes,
yshape=twodim_shapes,
)
def testConvolutions2D(self, xshape, yshape, dtype, mode, op):
jsp_op = getattr(jsp_signal, op)
osp_op = getattr(osp_signal, op)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
osp_fun = partial(osp_op, mode=mode)
@ -118,15 +110,13 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
tol=tol)
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_axis={}_type={}_bp={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis, type, bp),
"shape": shape, "dtype": dtype, "axis": axis, "type": type, "bp": bp}
for shape in [(5,), (4, 5), (3, 4, 5)]
for dtype in jtu.dtypes.floating + jtu.dtypes.integer
for axis in [0, -1]
for type in ['constant', 'linear']
for bp in [0, [0, 2]]))
@jtu.sample_product(
shape=[(5,), (4, 5), (3, 4, 5)],
dtype=jtu.dtypes.floating + jtu.dtypes.integer,
axis=[0, -1],
type=['constant', 'linear'],
bp=[0, [0, 2]],
)
def testDetrend(self, shape, dtype, axis, type, bp):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
@ -144,24 +134,19 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=tol)
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
f"_fs={fs}_window={window}_boundary={boundary}_detrend={detrend}"
f"_padded={padded}_nperseg={nperseg}_noverlap={noverlap}"
f"_axis={timeaxis}_nfft={nfft}",
"shape": shape, "dtype": dtype, "fs": fs, "window": window,
"nperseg": nperseg, "noverlap": noverlap, "nfft": nfft,
"detrend": detrend, "boundary": boundary, "padded": padded,
"timeaxis": timeaxis}
@jtu.sample_product(
[dict(shape=shape, nperseg=nperseg, noverlap=noverlap, timeaxis=timeaxis,
nfft=nfft)
for shape, nperseg, noverlap, timeaxis in stft_test_shapes
for dtype in default_dtypes
for fs in [1.0, 16000.0]
for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
for detrend in ['constant', 'linear', False]
for boundary in [None, 'even', 'odd', 'zeros']
for padded in [True, False]))
],
dtype=default_dtypes,
fs=[1.0, 16000.0],
window=['boxcar', 'triang', 'blackman', 'hamming', 'hann'],
detrend=['constant', 'linear', False],
boundary=[None, 'even', 'odd', 'zeros'],
padded=[True, False],
)
def testStftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
noverlap, nfft, detrend, boundary, padded,
timeaxis):
@ -193,26 +178,19 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
# Tests with `average == 'median'`` is excluded from `testCsd*`
# due to the issue:
# https://github.com/scipy/scipy/issues/15601
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
f"_xshape={jtu.format_shape_dtype_string(xshape, dtype)}"
f"_yshape={jtu.format_shape_dtype_string(yshape, dtype)}"
f"_average={average}_scaling={scaling}_nfft={nfft}"
f"_fs={fs}_window={window}_detrend={detrend}"
f"_nperseg={nperseg}_noverlap={noverlap}"
f"_axis={timeaxis}",
"xshape": xshape, "yshape": yshape, "dtype": dtype, "fs": fs,
"window": window, "nperseg": nperseg, "noverlap": noverlap,
"nfft": nfft, "detrend": detrend, "scaling": scaling,
"timeaxis": timeaxis, "average": average}
@jtu.sample_product(
[dict(xshape=xshape, yshape=yshape, nperseg=nperseg, noverlap=noverlap,
timeaxis=timeaxis, nfft=nfft)
for xshape, yshape, nperseg, noverlap, timeaxis in csd_test_shapes
for dtype in default_dtypes
for fs in [1.0, 16000.0]
for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
for detrend in ['constant', 'linear', False]
for scaling in ['density', 'spectrum']
for average in ['mean']))
],
dtype=default_dtypes,
fs=[1.0, 16000.0],
window=['boxcar', 'triang', 'blackman', 'hamming', 'hann'],
detrend=['constant', 'linear', False],
scaling=['density', 'spectrum'],
average=['mean'],
)
def testCsdAgainstNumpy(
self, *, xshape, yshape, dtype, fs, window, nperseg, noverlap, nfft,
detrend, scaling, timeaxis, average):
@ -242,25 +220,19 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
f"_average={average}_scaling={scaling}_nfft={nfft}"
f"_fs={fs}_window={window}_detrend={detrend}"
f"_nperseg={nperseg}_noverlap={noverlap}"
f"_axis={timeaxis}",
"shape": shape, "dtype": dtype, "fs": fs,
"window": window, "nperseg": nperseg, "noverlap": noverlap,
"nfft": nfft, "detrend": detrend, "scaling": scaling,
"timeaxis": timeaxis, "average": average}
for shape, unused_yshape, nperseg, noverlap, timeaxis in csd_test_shapes
for dtype in default_dtypes
for fs in [1.0, 16000.0]
for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
@jtu.sample_product(
[dict(shape=shape, nperseg=nperseg, noverlap=noverlap, timeaxis=timeaxis,
nfft=nfft)
for shape, _yshape, nperseg, noverlap, timeaxis in csd_test_shapes
for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
for detrend in ['constant', 'linear', False]
for scaling in ['density', 'spectrum']
for average in ['mean']))
],
dtype=default_dtypes,
fs=[1.0, 16000.0],
window=['boxcar', 'triang', 'blackman', 'hamming', 'hann'],
detrend=['constant', 'linear', False],
scaling=['density', 'spectrum'],
average=['mean'],
)
def testCsdWithSameParamAgainstNumpy(
self, *, shape, dtype, fs, window, nperseg, noverlap, nfft,
detrend, scaling, timeaxis, average):
@ -292,26 +264,20 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
f"_fs={fs}_window={window}"
f"_nperseg={nperseg}_noverlap={noverlap}_nfft={nfft}"
f"_detrend={detrend}_return_onesided={return_onesided}"
f"_scaling={scaling}_axis={timeaxis}_average={average}",
"shape": shape, "dtype": dtype, "fs": fs, "window": window,
"nperseg": nperseg, "noverlap": noverlap, "nfft": nfft,
"detrend": detrend, "return_onesided": return_onesided,
"scaling": scaling, "timeaxis": timeaxis, "average": average}
@jtu.sample_product(
[dict(shape=shape, nperseg=nperseg, noverlap=noverlap, timeaxis=timeaxis,
nfft=nfft)
for shape, nperseg, noverlap, timeaxis in welch_test_shapes
for dtype in default_dtypes
for fs in [1.0, 16000.0]
for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
for detrend in ['constant', 'linear', False]
for return_onesided in [True, False]
for scaling in ['density', 'spectrum']
for average in ['mean', 'median']))
],
dtype=default_dtypes,
fs=[1.0, 16000.0],
window=['boxcar', 'triang', 'blackman', 'hamming', 'hann'],
detrend=['constant', 'linear', False],
return_onesided=[True, False],
scaling=['density', 'spectrum'],
average=['mean', 'median'],
)
def testWelchAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
noverlap, nfft, detrend, return_onesided,
scaling, timeaxis, average):
@ -342,20 +308,14 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
f"_nperseg={nperseg}_noverlap={noverlap}"
f"_use_nperseg={use_nperseg}_use_overlap={use_noverlap}"
f"_axis={timeaxis}",
"shape": shape, "dtype": dtype,
"nperseg": nperseg, "noverlap": noverlap,
"use_nperseg": use_nperseg, "use_noverlap": use_noverlap,
"timeaxis": timeaxis}
@jtu.sample_product(
[dict(shape=shape, nperseg=nperseg, noverlap=noverlap, timeaxis=timeaxis)
for shape, nperseg, noverlap, timeaxis in welch_test_shapes
for use_nperseg in [False, True]
for use_noverlap in [False, True]
for dtype in jtu.dtypes.floating + jtu.dtypes.integer))
],
use_nperseg=[False, True],
use_noverlap=[False, True],
dtype=jtu.dtypes.floating + jtu.dtypes.integer,
)
def testWelchWithDefaultStepArgsAgainstNumpy(
self, *, shape, dtype, nperseg, noverlap, use_nperseg, use_noverlap,
timeaxis):
@ -386,23 +346,18 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
f"_fs={fs}_window={window}_boundary={boundary}"
f"_nperseg={nperseg}_noverlap={noverlap}_onesided={onesided}"
f"_timeaxis={timeaxis}_freqaxis{freqaxis}_nfft={nfft}",
"shape": shape, "dtype": dtype, "fs": fs, "window": window,
"nperseg": nperseg, "noverlap": noverlap, "nfft": nfft,
"onesided": onesided, "boundary": boundary,
"timeaxis": timeaxis, "freqaxis": freqaxis}
@jtu.sample_product(
[dict(shape=shape, nperseg=nperseg, noverlap=noverlap, timeaxis=timeaxis,
freqaxis=freqaxis, nfft=nfft)
for shape, nperseg, noverlap, timeaxis, freqaxis in istft_test_shapes
for dtype in default_dtypes
for fs in [1.0, 16000.0]
for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
for onesided in [False, True]
for boundary in [False, True]))
],
dtype=default_dtypes,
fs=[1.0, 16000.0],
window=['boxcar', 'triang', 'blackman', 'hamming', 'hann'],
onesided=[False, True],
boundary=[False, True],
)
def testIstftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
noverlap, nfft, onesided, boundary,
timeaxis, freqaxis):

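# Illustrative sketch (not part of the diff): the mixed form of
# jtu.sample_product used in the signal tests above. A positional argument is a
# list of dicts for parameters that must co-vary (here nperseg/noverlap drawn
# together and nfft derived from nperseg), while keyword arguments are
# independent axes; the decorator multiplies them all. The class, test name,
# and concrete parameter values below are hypothetical.
from jax._src import test_util as jtu  # assumed import path

class ExampleSignalParamsTest(jtu.JaxTestCase):

  @jtu.sample_product(
      [dict(nperseg=nperseg, noverlap=noverlap, nfft=nfft)
       for nperseg, noverlap in [(256, 128), (100, 50)]   # tied pair
       for nfft in [None, nperseg, nperseg * 2]           # depends on nperseg
      ],
      fs=[1.0, 16000.0],          # independent axis
      window=['boxcar', 'hann'],  # independent axis
  )
  def testSpectralParams(self, nperseg, noverlap, nfft, fs, window):
    # A real test would call the jax.scipy.signal function under test; this
    # only checks the co-variation constraint the comprehension encodes.
    self.assertLess(noverlap, nperseg)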
@ -16,7 +16,7 @@
from functools import partial
import itertools

from absl.testing import absltest, parameterized
from absl.testing import absltest

import numpy as np
import scipy.stats as osp_stats
@ -34,12 +34,10 @@ one_and_two_dim_shapes = [(4,), (3, 4), (3, 1), (1, 4)]

def genNamedParametersNArgs(n):
return parameterized.named_parameters(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix("", shapes, dtypes),
"shapes": shapes, "dtypes": dtypes}
for shapes in itertools.combinations_with_replacement(all_shapes, n)
for dtypes in itertools.combinations_with_replacement(jtu.dtypes.floating, n)))
return jtu.sample_product(
shapes=itertools.combinations_with_replacement(all_shapes, n),
dtypes=itertools.combinations_with_replacement(jtu.dtypes.floating, n),
)

# Allow implicit rank promotion in these tests, as virtually every test exercises it.
@ -179,14 +177,14 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)

@parameterized.named_parameters(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix("", [x_shape, alpha_shape], dtypes),
"shapes": [x_shape, alpha_shape], "dtypes": dtypes}
@jtu.sample_product(
shapes=[
[x_shape, alpha_shape]
for x_shape in one_and_two_dim_shapes
for alpha_shape in [(x_shape[0],), (x_shape[0] + 1,)]
for dtypes in itertools.combinations_with_replacement(jtu.dtypes.floating, 2)
))
],
dtypes=itertools.combinations_with_replacement(jtu.dtypes.floating, 2),
)
def testDirichletLogPdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())

@ -564,13 +562,12 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
lsp_stats.norm.cdf(np.full((4,), np.inf, np.float32)),
check_dtypes=False)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix("", [shape, shape], [x_dtype, p_dtype]),
"x_dtype": x_dtype,
"p_dtype": p_dtype,
"shape": shape}
for shape in [(2), (4,), (1, 5)]
for (x_dtype, p_dtype) in itertools.product(jtu.dtypes.integer, jtu.dtypes.floating)))
@jtu.sample_product(
[dict(x_dtype=x_dtype, p_dtype=p_dtype)
for x_dtype, p_dtype in itertools.product(jtu.dtypes.integer, jtu.dtypes.floating)
],
shape=[(2), (4,), (1, 5)],
)
def testMultinomialLogPmf(self, shape, x_dtype, p_dtype):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.multinomial.logpmf
@ -588,16 +585,8 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_x={}_mean={}_cov={}".format(
jtu.format_shape_dtype_string(x_shape, x_dtype),
jtu.format_shape_dtype_string(mean_shape, mean_dtype)
if mean_shape is not None else None,
jtu.format_shape_dtype_string(cov_shape, cov_dtype)
if cov_shape is not None else None),
"x_shape": x_shape, "x_dtype": x_dtype,
"mean_shape": mean_shape, "mean_dtype": mean_dtype,
"cov_shape": cov_shape, "cov_dtype": cov_dtype}
@jtu.sample_product(
[dict(x_shape=x_shape, mean_shape=mean_shape, cov_shape=cov_shape)
for x_shape, mean_shape, cov_shape in [
# # These test cases cover default values for mean/cov, but we don't
# # support those yet (and they seem not very valuable).
@ -616,9 +605,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
[(3, 4), (4,), (4, 4)],
[(2, 3, 4), (4,), (4, 4)],
]
for x_dtype, mean_dtype, cov_dtype in itertools.combinations_with_replacement(jtu.dtypes.floating, 3)
if (mean_shape is not None or mean_dtype == np.float32)
and (cov_shape is not None or cov_dtype == np.float32)))
],
[dict(x_dtype=x_dtype, mean_dtype=mean_dtype, cov_dtype=cov_dtype)
for x_dtype, mean_dtype, cov_dtype in itertools.combinations_with_replacement(jtu.dtypes.floating, 3)
],
# if (mean_shape is not None or mean_dtype == np.float32)
# and (cov_shape is not None or cov_dtype == np.float32)))
)
def testMultivariateNormalLogpdf(self, x_shape, x_dtype, mean_shape,
mean_dtype, cov_shape, cov_dtype):
rng = jtu.rand_default(self.rng())
@ -642,16 +635,8 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
rtol=1e-4, atol=1e-4)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_x={}_mean={}_cov={}".format(
jtu.format_shape_dtype_string(x_shape, x_dtype),
jtu.format_shape_dtype_string(mean_shape, mean_dtype)
if mean_shape is not None else None,
jtu.format_shape_dtype_string(cov_shape, cov_dtype)
if cov_shape is not None else None),
"x_shape": x_shape, "x_dtype": x_dtype,
"mean_shape": mean_shape, "mean_dtype": mean_dtype,
"cov_shape": cov_shape, "cov_dtype": cov_dtype}
@jtu.sample_product(
[dict(x_shape=x_shape, mean_shape=mean_shape, cov_shape=cov_shape)
for x_shape, mean_shape, cov_shape in [
# These test cases are where scipy flattens things, which has
# different batch semantics than some might expect, so we manually
@ -663,9 +648,11 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
[(1, 3, 2), (3, 2,), (5, 1, 2, 2)],
[(5, 3, 2), (1, 2,), (2, 2)],
]
for x_dtype, mean_dtype, cov_dtype in itertools.combinations_with_replacement(jtu.dtypes.floating, 3)
if (mean_shape is not None or mean_dtype == np.float32)
and (cov_shape is not None or cov_dtype == np.float32)))
],
[dict(x_dtype=x_dtype, mean_dtype=mean_dtype, cov_dtype=cov_dtype)
for x_dtype, mean_dtype, cov_dtype in itertools.combinations_with_replacement(jtu.dtypes.floating, 3)
],
)
def testMultivariateNormalLogpdfBroadcasted(self, x_shape, x_dtype, mean_shape,
mean_dtype, cov_shape, cov_dtype):
rng = jtu.rand_default(self.rng())
@ -691,12 +678,11 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
rtol=1e-4, atol=1e-4)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_ndim={ndim}_nbatch={nbatch}_dtype={dtype.__name__}",
"ndim": ndim, "nbatch": nbatch, "dtype": dtype}
for ndim in [2, 3]
for nbatch in [1, 3, 5]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
ndim=[2, 3],
nbatch=[1, 3, 5],
dtype=jtu.dtypes.floating,
)
def testMultivariateNormalLogpdfBatch(self, ndim, nbatch, dtype):
# Regression test for #5570
rng = jtu.rand_default(self.rng())
@ -709,23 +695,14 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
result2 = jax.vmap(lsp_stats.multivariate_normal.logpdf)(x, mean, cov)
self.assertArraysEqual(result1, result2, check_dtypes=False)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_inshape={}_outsize={}_weights={}_method={}_func={}".format(
jtu.format_shape_dtype_string(inshape, dtype),
outsize, weights, method, func),
"dtype": dtype,
"inshape": inshape,
"outsize": outsize,
"weights": weights,
"method": method,
"func": func}
for inshape in [(50,), (3, 50), (2, 12)]
for dtype in jtu.dtypes.floating
for outsize in [None, 10]
for weights in [False, True]
for method in [None, "scott", "silverman", 1.5, "callable"]
for func in [None, "evaluate", "logpdf", "pdf"]))
@jtu.sample_product(
inshape=[(50,), (3, 50), (2, 12)],
dtype=jtu.dtypes.floating,
outsize=[None, 10],
weights=[False, True],
method=[None, "scott", "silverman", 1.5, "callable"],
func=[None, "evaluate", "logpdf", "pdf"],
)
def testKde(self, inshape, dtype, outsize, weights, method, func):
if method == "callable":
method = lambda kde: jax.numpy.power(kde.neff, -1./(kde.d+4))
@ -764,12 +741,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CompileAndCheck(
lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]),
"dtype": dtype,
"shape": shape}
for shape in [(15,), (3, 15), (1, 12)]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
shape=[(15,), (3, 15), (1, 12)],
dtype=jtu.dtypes.floating,
)
def testKdeIntegrateGaussian(self, shape, dtype):
def scipy_fun(dataset, weights):
kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights))
@ -796,12 +771,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CompileAndCheck(
lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]),
"dtype": dtype,
"shape": shape}
for shape in [(15,), (12,)]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
shape=[(15,), (12,)],
dtype=jtu.dtypes.floating,
)
def testKdeIntegrateBox1d(self, shape, dtype):
def scipy_fun(dataset, weights):
kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights))
@ -820,12 +793,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CompileAndCheck(
lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]),
"dtype": dtype,
"shape": shape}
for shape in [(15,), (3, 15), (1, 12)]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
shape=[(15,), (3, 15), (1, 12)],
dtype=jtu.dtypes.floating,
)
def testKdeIntegrateKde(self, shape, dtype):
def scipy_fun(dataset, weights):
kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights))
@ -848,12 +819,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CompileAndCheck(
lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]),
"dtype": dtype,
"shape": shape}
for shape in [(15,), (3, 15), (1, 12)]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
shape=[(15,), (3, 15), (1, 12)],
dtype=jtu.dtypes.floating,
)
def testKdeResampleShape(self, shape, dtype):
def resample(key, dataset, weights, *, shape):
kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights))
@ -878,12 +847,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
result = func(*args)
assert result.shape == (ndim, 4)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]),
"dtype": dtype,
"shape": shape}
for shape in [(15,), (1, 12)]
for dtype in jtu.dtypes.floating))
@jtu.sample_product(
shape=[(15,), (1, 12)],
dtype=jtu.dtypes.floating,
)
def testKdeResample1d(self, shape, dtype):
rng = jtu.rand_default(self.rng())
dataset = rng(shape, dtype)

File diff suppressed because it is too large
@ -205,18 +205,15 @@ class SparsifyTest(jtu.JaxTestCase):

self.assertAllClose(out.todense(), x.todense() + y.todense())

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_nbatch={}_ndense={}_unique_indices={}".format(
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense,
unique_indices),
"shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense,
"unique_indices": unique_indices}
@jtu.sample_product(
[dict(shape=shape, n_batch=n_batch, n_dense=n_dense)
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
for dtype in (jtu.dtypes.integer + jtu.dtypes.floating +
jtu.dtypes.complex)
for n_batch in range(len(shape) + 1)
for n_dense in range(len(shape) + 1 - n_batch)
for unique_indices in [True, False]))
],
dtype=jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex,
unique_indices=[True, False],
)
def testSparseMul(self, shape, dtype, n_batch, n_dense, unique_indices):
rng_sparse = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero)
x = BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch,
@ -281,10 +278,8 @@ class SparsifyTest(jtu.JaxTestCase):
res_sparse = res_sparse.todense()
self.assertArraysAllClose(res_dense, res_sparse)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_dimensions={}_nbatch={}_ndense={}".format(
jtu.format_shape_dtype_string(shape, np.float32), dimensions, n_batch, n_dense),
"shape": shape, "dimensions": dimensions, "n_batch": n_batch, "n_dense": n_dense}
@jtu.sample_product(
[dict(shape=shape, dimensions=dimensions, n_batch=n_batch, n_dense=n_dense)
for shape, dimensions in [
[(1,), (0,)],
[(1,), (-1,)],
@ -294,7 +289,9 @@ class SparsifyTest(jtu.JaxTestCase):
[(2, 1, 3, 1), (3,)],
]
for n_batch in range(len(shape) + 1)
for n_dense in range(len(shape) - n_batch + 1)))
for n_dense in range(len(shape) - n_batch + 1)
],
)
def testSparseSqueeze(self, shape, dimensions, n_batch, n_dense):
rng = jtu.rand_default(self.rng())

@ -307,9 +304,8 @@ class SparsifyTest(jtu.JaxTestCase):

self.assertAllClose(result_sparse, result_dense)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_shapes={shapes}_func={func}_nbatch={n_batch}",
"shapes": shapes, "func": func, "n_batch": n_batch}
@jtu.sample_product(
[dict(shapes=shapes, func=func, n_batch=n_batch)
for shapes, func, n_batch in [
([(4,), (4,)], "concatenate", 0),
([(4,), (4,)], "stack", 0),
@ -328,7 +324,9 @@ class SparsifyTest(jtu.JaxTestCase):
([(2, 4), (2, 5)], "hstack", 2),
([(2, 4), (4,), (3, 4)], "vstack", 0),
([(1, 4), (4,), (1, 4)], "vstack", 0),
]))
]
],
)
def testSparseConcatenate(self, shapes, func, n_batch):
f = self.sparsify(getattr(jnp, func))
rng = jtu.rand_some_zero(self.rng())
@ -336,9 +334,8 @@ class SparsifyTest(jtu.JaxTestCase):
sparrs = [BCOO.fromdense(arr, n_batch=n_batch) for arr in arrs]
self.assertArraysEqual(f(arrs), f(sparrs).todense())

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_{shape}->{new_shape}_n_batch={n_batch}_n_dense={n_dense}",
"shape": shape, "new_shape": new_shape, "n_batch": n_batch, "n_dense": n_dense}
@jtu.sample_product(
[dict(shape=shape, new_shape=new_shape, n_batch=n_batch, n_dense=n_dense)
for shape, new_shape, n_batch, n_dense in [
[(6,), (2, 3), 0, 0],
[(1, 4), (2, 2), 0, 0],
@ -348,7 +345,9 @@ class SparsifyTest(jtu.JaxTestCase):
[(2, 3, 4), (3, 8), 0, 0],
[(2, 3, 4), (1, 2, 12), 1, 0],
[(2, 3, 4), (6, 2, 2), 2, 0],
]))
]
],
)
def testSparseReshapeMethod(self, shape, new_shape, n_batch, n_dense):
rng = jtu.rand_some_zero(self.rng())
arr = rng(shape, 'int32')
@ -359,10 +358,9 @@ class SparsifyTest(jtu.JaxTestCase):

self.assertArraysEqual(arr2, arr2_sparse.todense())

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_{shape}->{new_shape}_n_batch={n_batch}_n_dense={n_dense}_dimensions={dimensions}",
"shape": shape, "new_shape": new_shape, "n_batch": n_batch, "n_dense": n_dense,
"dimensions": dimensions}
@jtu.sample_product(
[dict(shape=shape, new_shape=new_shape, n_batch=n_batch, n_dense=n_dense,
dimensions=dimensions)
for shape, new_shape, n_batch, n_dense, dimensions in [
[(2, 3, 4), (24,), 0, 0, None],
[(2, 3, 4), (24,), 0, 0, (0, 1, 2)],
@ -375,7 +373,9 @@ class SparsifyTest(jtu.JaxTestCase):
[(4, 2, 3), (2, 2, 6), 1, 0, (0, 2, 1)],
[(2, 3, 4), (6, 4), 2, 0, (0, 1, 2)],
[(2, 3, 4), (6, 4), 2, 0, (1, 0, 2)],
]))
]
],
)
def testSparseReshapeWithDimensions(self, shape, new_shape, n_batch, n_dense, dimensions):
rng = jtu.rand_some_zero(self.rng())
arr = rng(shape, 'int32')

@ -15,7 +15,6 @@
|
||||
"""Tests for Stax library."""
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -50,105 +49,81 @@ def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape):
|
||||
@jtu.with_config(jax_numpy_rank_promotion="allow")
|
||||
class StaxTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_shape={shape}", "shape": shape}
|
||||
for shape in [(2, 3), (5,)]))
|
||||
@jtu.sample_product(shape=[(2, 3), (5,)])
|
||||
def testRandnInitShape(self, shape):
|
||||
key = random.PRNGKey(0)
|
||||
out = stax.randn()(key, shape)
|
||||
self.assertEqual(out.shape, shape)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_shape={shape}", "shape": shape}
|
||||
for shape in [(2, 3), (2, 3, 4)]))
|
||||
@jtu.sample_product(shape=[(2, 3), (2, 3, 4)])
|
||||
def testGlorotInitShape(self, shape):
|
||||
key = random.PRNGKey(0)
|
||||
out = stax.glorot()(key, shape)
|
||||
self.assertEqual(out.shape, shape)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
"_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
|
||||
.format(channels, filter_shape, padding, strides, input_shape),
|
||||
"channels": channels, "filter_shape": filter_shape, "padding": padding,
|
||||
"strides": strides, "input_shape": input_shape}
|
||||
for channels in [2, 3]
|
||||
for filter_shape in [(1, 1), (2, 3)]
|
||||
for padding in ["SAME", "VALID"]
|
||||
for strides in [None, (2, 1)]
|
||||
for input_shape in [(2, 10, 11, 1)]))
|
||||
@jtu.sample_product(
|
||||
channels=[2, 3],
|
||||
filter_shape=[(1, 1), (2, 3)],
|
||||
padding=["SAME", "VALID"],
|
||||
strides=[None, (2, 1)],
|
||||
input_shape=[(2, 10, 11, 1)],
|
||||
)
|
||||
def testConvShape(self, channels, filter_shape, padding, strides,
|
||||
input_shape):
|
||||
init_fun, apply_fun = stax.Conv(channels, filter_shape, strides=strides,
|
||||
padding=padding)
|
||||
_CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
"_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
|
||||
.format(channels, filter_shape, padding, strides, input_shape),
|
||||
"channels": channels, "filter_shape": filter_shape, "padding": padding,
|
||||
"strides": strides, "input_shape": input_shape}
|
||||
for channels in [2, 3]
|
||||
for filter_shape in [(1, 1), (2, 3), (3, 3)]
|
||||
for padding in ["SAME", "VALID"]
|
||||
for strides in [None, (2, 1), (2, 2)]
|
||||
for input_shape in [(2, 10, 11, 1)]))
|
||||
@jtu.sample_product(
|
||||
channels=[2, 3],
|
||||
filter_shape=[(1, 1), (2, 3), (3, 3)],
|
||||
padding=["SAME", "VALID"],
|
||||
strides=[None, (2, 1), (2, 2)],
|
||||
input_shape=[(2, 10, 11, 1)],
|
||||
)
|
||||
def testConvTransposeShape(self, channels, filter_shape, padding, strides,
|
||||
input_shape):
|
||||
init_fun, apply_fun = stax.ConvTranspose(channels, filter_shape, # 2D
|
||||
strides=strides, padding=padding)
|
||||
_CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
"_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
|
||||
.format(channels, filter_shape, padding, strides, input_shape),
|
||||
"channels": channels, "filter_shape": filter_shape, "padding": padding,
|
||||
"strides": strides, "input_shape": input_shape}
|
||||
for channels in [2, 3]
|
||||
for filter_shape in [(1,), (2,), (3,)]
|
||||
for padding in ["SAME", "VALID"]
|
||||
for strides in [None, (1,), (2,)]
|
||||
for input_shape in [(2, 10, 1)]))
|
||||
|
||||
@jtu.sample_product(
|
||||
channels=[2, 3],
|
||||
filter_shape=[(1,), (2,), (3,)],
|
||||
padding=["SAME", "VALID"],
|
||||
strides=[None, (1,), (2,)],
|
||||
input_shape=[(2, 10, 1)],
|
||||
)
|
||||
def testConv1DTransposeShape(self, channels, filter_shape, padding, strides,
|
||||
input_shape):
|
||||
init_fun, apply_fun = stax.Conv1DTranspose(channels, filter_shape,
|
||||
strides=strides, padding=padding)
|
||||
_CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_out_dim={}_input_shape={}"
|
||||
.format(out_dim, input_shape),
|
||||
"out_dim": out_dim, "input_shape": input_shape}
|
||||
for out_dim in [3, 4]
|
||||
for input_shape in [(2, 3), (3, 4)]))
|
||||
@jtu.sample_product(
|
||||
out_dim=[3, 4],
|
||||
input_shape=[(2, 3), (3, 4)],
|
||||
)
|
||||
def testDenseShape(self, out_dim, input_shape):
|
||||
init_fun, apply_fun = stax.Dense(out_dim)
|
||||
_CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_input_shape={}_nonlinear={}"
|
||||
.format(input_shape, nonlinear),
|
||||
"input_shape": input_shape, "nonlinear": nonlinear}
|
||||
for input_shape in [(2, 3), (2, 3, 4)]
|
||||
for nonlinear in ["Relu", "Sigmoid", "Elu", "LeakyRelu"]))
|
||||
@jtu.sample_product(
|
||||
input_shape=[(2, 3), (2, 3, 4)],
|
||||
nonlinear=["Relu", "Sigmoid", "Elu", "LeakyRelu"],
|
||||
)
|
||||
def testNonlinearShape(self, input_shape, nonlinear):
|
||||
init_fun, apply_fun = getattr(stax, nonlinear)
|
||||
_CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_window_shape={}_padding={}_strides={}_input_shape={}"
                        "_maxpool={}_spec={}"
                        .format(window_shape, padding, strides, input_shape,
                                max_pool, spec),
       "window_shape": window_shape, "padding": padding, "strides": strides,
       "input_shape": input_shape, "max_pool": max_pool, "spec": spec}
      for window_shape in [(1, 1), (2, 3)]
      for padding in ["VALID"]
      for strides in [None, (2, 1)]
      for input_shape in [(2, 5, 6, 4)]
      for max_pool in [False, True]
      for spec in ["NHWC", "NCHW", "WHNC", "WHCN"]))
  @jtu.sample_product(
      window_shape=[(1, 1), (2, 3)],
      padding=["VALID"],
      strides=[None, (2, 1)],
      input_shape=[(2, 5, 6, 4)],
      max_pool=[False, True],
      spec=["NHWC", "NCHW", "WHNC", "WHCN"],
  )
  def testPoolingShape(self, window_shape, padding, strides, input_shape,
                       max_pool, spec):
    layer = stax.MaxPool if max_pool else stax.AvgPool
@ -156,49 +131,41 @@ class StaxTest(jtu.JaxTestCase):
                                spec=spec)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": f"_shape={input_shape}",
       "input_shape": input_shape}
      for input_shape in [(2, 3), (2, 3, 4)]))
  @jtu.sample_product(input_shape=[(2, 3), (2, 3, 4)])
  def testFlattenShape(self, input_shape):
    init_fun, apply_fun = stax.Flatten
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": f"_input_shape={input_shape}_spec={i}",
       "input_shape": input_shape, "spec": spec}
      for input_shape in [(2, 5, 6, 1)]
      for i, spec in enumerate([
          [stax.Conv(3, (2, 2))],
          [stax.Conv(3, (2, 2)), stax.Flatten, stax.Dense(4)]])))
  @jtu.sample_product(
      input_shape=[(2, 5, 6, 1)],
      spec=[
          [stax.Conv(3, (2, 2))],
          [stax.Conv(3, (2, 2)), stax.Flatten, stax.Dense(4)],
      ],
  )
  def testSerialComposeLayersShape(self, input_shape, spec):
    init_fun, apply_fun = stax.serial(*spec)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": f"_input_shape={input_shape}",
       "input_shape": input_shape}
      for input_shape in [(3, 4), (2, 5, 6, 1)]))
  @jtu.sample_product(input_shape=[(3, 4), (2, 5, 6, 1)])
  def testDropoutShape(self, input_shape):
    init_fun, apply_fun = stax.Dropout(0.9)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": f"_input_shape={input_shape}",
       "input_shape": input_shape}
      for input_shape in [(3, 4), (2, 5, 6, 1)]))
  @jtu.sample_product(input_shape=[(3, 4), (2, 5, 6, 1)])
  def testFanInSum(self, input_shape):
    init_fun, apply_fun = stax.FanInSum
    _CheckShapeAgreement(self, init_fun, apply_fun, [input_shape, input_shape])

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": f"_inshapes={input_shapes}_axis={axis}",
       "input_shapes": input_shapes, "axis": axis}
  @jtu.sample_product(
      [dict(input_shapes=input_shapes, axis=axis)
       for input_shapes, axis in [
           ([(2, 3), (2, 1)], 1),
           ([(2, 3), (2, 1)], -1),
           ([(1, 2, 4), (1, 1, 4)], 1),
      ]))
       ]
      ],
  )
  def testFanInConcat(self, input_shapes, axis):
    init_fun, apply_fun = stax.FanInConcat(axis)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shapes)
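testFanInConcat above uses the grouped form of jtu.sample_product: parameters that are only meaningful in particular combinations travel together as dicts inside a positional list, while ordinary keyword lists remain independent axes of the product. A sketch under that assumption; the _valid_cases name and the extra dtype axis are hypothetical and only illustrate the call shape:

  # Correlated (input_shapes, axis) pairs are kept together, so no invalid
  # combination is ever generated.
  _valid_cases = [
      dict(input_shapes=[(2, 3), (2, 1)], axis=1),
      dict(input_shapes=[(2, 3), (2, 1)], axis=-1),
      dict(input_shapes=[(1, 2, 4), (1, 1, 4)], axis=1),
  ]

  @jtu.sample_product(
      _valid_cases,                    # positional: dicts of correlated parameters
      dtype=[np.float32, np.float64],  # keyword: an independent axis of the product
  )
  def testConcatExample(self, input_shapes, axis, dtype):
    ...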

@ -46,7 +46,7 @@ from jax._src.nn import initializers as nn_initializers
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.util import curry, unzip2, prod, safe_zip
from jax._src.util import unzip2, prod, safe_zip
from jax._src.lax import parallel as lax_parallel
from jax._src.lax.parallel import pgather
from jax.interpreters import batching, pxla
@ -876,13 +876,12 @@ class XMapTestManualSPMD(ManualSPMDTestMixin, XMapTestCase):

class NamedNumPyTest(XMapTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": f"_{reduction.__name__}_axes={axes}_i={mapped_axis}",
       "reduction": reduction, "axes": axes, "mapped_axis": mapped_axis}
      for reduction in (jnp.sum, jnp.max, jnp.min, jnp.mean, jnp.var, jnp.std,
                        jscipy.special.logsumexp)
      for axes in (0, 'i', (1,), ('i',), (0, 1), (0, 'i'), ('i', 0))
      for mapped_axis in range(3)))
  @jtu.sample_product(
      reduction=(jnp.sum, jnp.max, jnp.min, jnp.mean, jnp.var, jnp.std,
                 jscipy.special.logsumexp),
      axes=(0, 'i', (1,), ('i',), (0, 1), (0, 'i'), ('i', 0)),
      mapped_axis=range(3),
  )
  def testReductions(self, reduction, axes, mapped_axis):
    axes_t = axes if isinstance(axes, tuple) else (axes,)
    ref_red = partial(reduction,
@ -901,25 +900,15 @@ class NamedNumPyTest(XMapTestCase):

class NamedRandomTest(XMapTestCase):

  @curry
  def parameterize_by_sampler(extra, f, subset):
    if extra is None:
      extra = [("", {})]
    else:
      extra = list(extra)
    subset_fn = jtu.cases_from_list if subset else lambda x: x
    return parameterized.named_parameters(subset_fn(
        {"testcase_name": name + extra_name, "distr_sample": sample, **extra_kwargs}
        for name, sample in [
            ("Uniform", jax.random.uniform),
            ("Normal", jax.random.normal),
            ("Bernoulli", partial(jax.random.bernoulli, p=0.5)),
            ("TruncatedNormal", partial(jax.random.truncated_normal, lower=-2, upper=2)),
        ]
        for extra_name, extra_kwargs in extra))(f)
  SAMPLERS = [
      ("Uniform", jax.random.uniform),
      ("Normal", jax.random.normal),
      ("Bernoulli", partial(jax.random.bernoulli, p=0.5)),
      ("TruncatedNormal", partial(jax.random.truncated_normal, lower=-2, upper=2)),
  ]
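With the parameterize_by_sampler helper deleted, the unsharded test is driven straight from SAMPLERS through parameterized.parameters, which unpacks each (name, sampler) tuple into positional arguments; that is why the new signature gains a distr_name argument. A sketch of how one case runs; the method body here is hypothetical, the real test wraps the sampler in an xmap:

  @parameterized.parameters(*SAMPLERS)
  def testSamplerExample(self, distr_name, distr_sample):
    # distr_name only labels the case; distr_sample is the callable under test.
    key = jax.random.PRNGKey(0)
    out = distr_sample(key, shape=(3, 4))
    self.assertEqual(out.shape, (3, 4))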

  @parameterize_by_sampler(None, subset=False)
  def testSamplerSharding(self, distr_sample):
  @parameterized.parameters(*SAMPLERS)
  def testSamplerSharding(self, distr_name, distr_sample):
    def sample(shape, map_size):
      return xmap(lambda: distr_sample(jax.random.PRNGKey(0), shape=shape),
                  in_axes=(), out_axes=[None, 'i', ...], axis_sizes={'i': map_size})()
@ -931,12 +920,15 @@ class NamedRandomTest(XMapTestCase):
    with self.assertRaisesRegex(ValueError, error):
      sample(NamedShape(3, i=4), 5)

  @parameterize_by_sampler(
      ((f"_mesh={mesh}_resources={sorted(axis_resources.items())}",
        {"axis_resources": tuple(axis_resources.items()), "mesh": tuple(mesh)})
       for axis_resources, mesh in schedules({'i': 4, 'j': 6})), subset=True)
  @jtu.sample_product(
      [dict(distr_name=name, distr_sample=sample)
       for name, sample in SAMPLERS],
      [dict(axis_resources=tuple(axis_resources.items()), mesh=tuple(mesh))
       for axis_resources, mesh in schedules({'i': 4, 'j': 6})],
  )
  @jtu.with_mesh_from_kwargs
  def testSamplerResourceIndependence(self, distr_sample, axis_resources, mesh):
  def testSamplerResourceIndependence(self, distr_name, distr_sample,
                                      axis_resources, mesh):
    def sample(axis_resources):
      return xmap(lambda: distr_sample(jax.random.PRNGKey(0), shape=NamedShape(3, i=4, j=6)),
                  in_axes=(), out_axes=['i', 'j', ...], axis_sizes={'i': 4, 'j': 6},
@ -965,13 +957,12 @@ class NamedNNTest(XMapTestCase):
    with self.assertRaisesRegex(ValueError, "to match the size of axis i, but 3 != 5"):
      f(jnp.ones(5, dtype='int32'))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": f"_map_in={map_in}_map_out={map_out}_fan={fan}_distr={distr}",
       "map_in": map_in, "map_out": map_out, "fan": fan,
       "distr": distr}
      for map_in, map_out in [(True, False), (False, True), (True, True)]
      for fan in ['fan_in', 'fan_out', 'fan_avg']
      for distr in ['uniform', 'normal', 'truncated_normal']))
  @jtu.sample_product(
      [dict(map_in=map_in, map_out=map_out)
       for map_in, map_out in [(True, False), (False, True), (True, True)]],
      fan=['fan_in', 'fan_out', 'fan_avg'],
      distr=['uniform', 'normal', 'truncated_normal'],
  )
  def testVarianceScaling(self, map_in, map_out, fan, distr):
    shape = (80, 50, 7)
    fan_in, fan_out = nn_initializers._compute_fans(NamedShape(*shape), 0, 1)
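In testVarianceScaling above, the grouped (map_in, map_out) pairs contribute three cases (the all-False combination is deliberately excluded), and fan and distr multiply in as independent keyword axes. A small counting sketch, assuming the default subsampling behaviour of jtu.sample_product:

  # Candidate cases described by the decorator above, before subsampling.
  n_pairs = 3      # (True, False), (False, True), (True, True)
  n_fans = 3       # 'fan_in', 'fan_out', 'fan_avg'
  n_distrs = 3     # 'uniform', 'normal', 'truncated_normal'
  n_candidates = n_pairs * n_fans * n_distrs   # 27
  # Of these, min(n_candidates, FLAGS.num_generated_cases) are instantiated,
  # drawn with a fixed seed so the selection is stable across runs.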