From fffdb2daa8b22ba5bb7adf742008c1bce0d945b6 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Mon, 1 Jun 2020 17:19:23 -0400
Subject: [PATCH] =?UTF-8?q?Make=20check=5Fdtypes,=20atol,=20and=20rtol=20k?=
 =?UTF-8?q?eyword-only=20arguments=20in=20jax.test=5F=E2=80=A6=20(#3280)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util
  APIs. Default to check_dtypes=True. Remove explicit usages of
  check_dtypes=True from tests. This mostly just removes visual noise from
  tests. Testing for exact type equality is the sensible default, although
  there are cases where opting out makes sense. No functional changes
  intended.

* Fix a number of lax reference implementations to preserve types.
---
 jax/lax_reference.py | 25 +-
 jax/test_util.py | 26 +-
 tests/api_test.py | 97 +++----
 tests/array_interoperability_test.py | 8 +-
 tests/batching_test.py | 83 +++---
 tests/callback_test.py | 8 +-
 tests/fft_test.py | 14 +-
 tests/host_callback_test.py | 10 +-
 tests/infeed_test.py | 6 +-
 tests/jet_test.py | 6 +-
 tests/lax_control_flow_test.py | 24 +-
 tests/lax_numpy_einsum_test.py | 2 +-
 tests/lax_numpy_indexing_test.py | 34 +--
 tests/lax_numpy_test.py | 364 +++++++++++++--------------
 tests/lax_numpy_vectorize_test.py | 8 +-
 tests/lax_scipy_sparse_test.py | 9 +-
 tests/lax_scipy_test.py | 12 +-
 tests/lax_test.py | 93 ++++---
 tests/linalg_test.py | 121 ++++-----
 tests/loops_test.py | 56 ++---
 tests/masking_test.py | 4 +-
 tests/multibackend_test.py | 4 +-
 tests/nn_test.py | 14 +-
 tests/optimizers_test.py | 4 +-
 tests/parallel_test.py | 8 +-
 tests/pmap_test.py | 31 ++-
 tests/random_test.py | 33 ++-
 tests/scipy_ndimage_test.py | 8 +-
 tests/scipy_signal_test.py | 4 +-
 tests/scipy_stats_test.py | 52 ++--
 tests/vectorize_test.py | 6 +-
 31 files changed, 559 insertions(+), 615 deletions(-)

diff --git a/jax/lax_reference.py b/jax/lax_reference.py
index e7e301f13..f3240e830 100644
--- a/jax/lax_reference.py
+++ b/jax/lax_reference.py
@@ -32,7 +32,7 @@ neg = np.negative
 sign = np.sign
 floor = np.floor
 ceil = np.ceil
-round = lambda x: np.trunc(x + np.copysign(.5, x))
+round = lambda x: np.trunc(x + np.copysign(.5, x)).astype(x.dtype)
 nextafter = np.nextafter
 is_finite = np.isfinite

@@ -47,7 +47,7 @@ cos = np.cos
 atan2 = np.arctan2
 sqrt = np.sqrt
-rsqrt = lambda x: 1. / np.sqrt(x)
+rsqrt = lambda x: np.ones_like(x) / np.sqrt(x)
 square = np.square
 reciprocal = np.reciprocal
 tan = np.tan
@@ -60,16 +60,17 @@ asinh = np.arcsinh
 acosh = np.arccosh
 atanh = np.arctanh
-betainc = scipy.special.betainc
-lgamma = scipy.special.gammaln
-digamma = scipy.special.digamma
+def betainc(a, b, x): return scipy.special.betainc(a, b, x).astype(x.dtype)
+def lgamma(x): return scipy.special.gammaln(x).astype(x.dtype)
+def digamma(x): return scipy.special.digamma(x).astype(x.dtype)
 igamma = scipy.special.gammainc
 igammac = scipy.special.gammaincc
-erf = scipy.special.erf
-erfc = scipy.special.erfc
-erf_inv = scipy.special.erfinv
-bessel_i0e = scipy.special.i0e
-bessel_i1e = scipy.special.i1e
+def erf(x): return scipy.special.erf(x).astype(x.dtype)
+def erfc(x): return scipy.special.erfc(x).astype(x.dtype)
+def erf_inv(x): return scipy.special.erfinv(x).astype(x.dtype)
+
+def bessel_i0e(x): return scipy.special.i0e(x).astype(x.dtype)
+def bessel_i1e(x): return scipy.special.i1e(x).astype(x.dtype)

 real = np.real
 imag = np.imag
@@ -150,7 +151,7 @@ def bitcast_convert_type(operand, dtype):
   return np.asarray(operand).view(dtype)

 def clamp(min, operand, max):
-  return np.clip(operand, np.clip(min, None, max), max)
+  return np.clip(operand, np.clip(min, None, max), max).astype(operand.dtype)

 def concatenate(operands, dimension):
   return np.concatenate(operands, axis=dimension)
@@ -295,8 +296,6 @@ def sort_key_val(keys, values, dimension=-1):
   idxs[dimension] = np.argsort(keys, axis=dimension)
   return keys[tuple(idxs)], values[tuple(idxs)]

-# TODO untake
-
 ### conv util

 def _conv(lhs, rhs, window_strides, pads):
diff --git a/jax/test_util.py b/jax/test_util.py
index c8b85a341..6f96262b9 100644
--- a/jax/test_util.py
+++ b/jax/test_util.py
@@ -719,13 +719,14 @@ class JaxTestCase(parameterized.TestCase):
   def rng(self):
     return self._rng

-  def assertArraysEqual(self, x, y, check_dtypes):
+  def assertArraysEqual(self, x, y, *, check_dtypes=True):
     """Assert that x and y arrays are exactly equal."""
     if check_dtypes:
       self.assertDtypesMatch(x, y)
     np.testing.assert_equal(x, y)

-  def assertArraysAllClose(self, x, y, check_dtypes, atol=None, rtol=None):
+  def assertArraysAllClose(self, x, y, *, check_dtypes=True, atol=None,
+                           rtol=None):
     """Assert that x and y are close (up to numerical tolerances)."""
     self.assertEqual(x.shape, y.shape)
     atol = max(tolerance(_dtype(x), atol), tolerance(_dtype(y), atol))
@@ -740,18 +741,20 @@ class JaxTestCase(parameterized.TestCase):
     if FLAGS.jax_enable_x64:
       self.assertEqual(_dtype(x), _dtype(y))

-  def assertAllClose(self, x, y, check_dtypes, atol=None, rtol=None):
+  def assertAllClose(self, x, y, *, check_dtypes=True, atol=None, rtol=None):
     """Assert that x and y, either arrays or nested tuples/lists, are close."""
     if isinstance(x, dict):
       self.assertIsInstance(y, dict)
       self.assertEqual(set(x.keys()), set(y.keys()))
       for k in x.keys():
-        self.assertAllClose(x[k], y[k], check_dtypes, atol=atol, rtol=rtol)
+        self.assertAllClose(x[k], y[k], check_dtypes=check_dtypes, atol=atol,
+                            rtol=rtol)
     elif is_sequence(x) and not hasattr(x, '__array__'):
       self.assertTrue(is_sequence(y) and not hasattr(y, '__array__'))
       self.assertEqual(len(x), len(y))
       for x_elt, y_elt in zip(x, y):
-        self.assertAllClose(x_elt, y_elt, check_dtypes, atol=atol, rtol=rtol)
+        self.assertAllClose(x_elt, y_elt, check_dtypes=check_dtypes, atol=atol,
+                            rtol=rtol)
     elif hasattr(x, '__array__') or np.isscalar(x):
       self.assertTrue(hasattr(y, '__array__') or np.isscalar(y))
       if check_dtypes:
@@ -772,7 +775,7 @@
class JaxTestCase(parameterized.TestCase): self.assertMultiLineEqual(expected_clean, what_clean, msg="Found\n{}\nExpecting\n{}".format(what, expected)) - def _CompileAndCheck(self, fun, args_maker, check_dtypes, + def _CompileAndCheck(self, fun, args_maker, *, check_dtypes=True, rtol=None, atol=None): """Helper method for running JAX compilation and allclose assertions.""" args = args_maker() @@ -802,8 +805,10 @@ class JaxTestCase(parameterized.TestCase): python_should_be_executing = False compiled_ans = cfun(*args) - self.assertAllClose(python_ans, monitored_ans, check_dtypes, atol, rtol) - self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol) + self.assertAllClose(python_ans, monitored_ans, check_dtypes=check_dtypes, + atol=atol, rtol=rtol) + self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes, + atol=atol, rtol=rtol) args = args_maker() @@ -813,10 +818,11 @@ class JaxTestCase(parameterized.TestCase): python_should_be_executing = False compiled_ans = cfun(*args) - self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol) + self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes, + atol=atol, rtol=rtol) def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker, - check_dtypes=False, tol=None): + check_dtypes=True, tol=None): args = args_maker() lax_ans = lax_op(*args) numpy_ans = numpy_reference_op(*args) diff --git a/tests/api_test.py b/tests/api_test.py index 63c73461f..bb301cacb 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -462,7 +462,7 @@ class APITest(jtu.JaxTestCase): def test_grad_and_aux_basic(self): g, aux = grad(lambda x: (x**3, [x**2]), has_aux=True)(3.) - self.assertAllClose(g, grad(lambda x: x**3)(3.), check_dtypes=True) + self.assertAllClose(g, grad(lambda x: x**3)(3.)) self.assertAllClose(aux, [9.], check_dtypes=False) def test_grad_and_aux_nested(self): @@ -557,7 +557,7 @@ class APITest(jtu.JaxTestCase): res2 = api.jit(inner)(5.) return res1 + res2 - self.assertAllClose((45., 9.), api.jvp(func, (5.,), (1.,)), check_dtypes=True) + self.assertAllClose((45., 9.), api.jvp(func, (5.,), (1.,))) def test_complex_grad_raises_error(self): @@ -596,7 +596,7 @@ class APITest(jtu.JaxTestCase): ans = jacrev(f)(zs) expected = grad(f)(zs) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) def test_complex_input_jacfwd_raises_error(self): self.assertRaises(TypeError, lambda: jacfwd(lambda x: jnp.sin(x))(1 + 2j)) @@ -980,7 +980,7 @@ class APITest(jtu.JaxTestCase): futures = [executor.submit(partial(f, x)) for x in xs] ys = [f.result() for f in futures] for x, y in zip(xs, ys): - self.assertAllClose(x, y, check_dtypes=True) + self.assertAllClose(x, y) def test_concurrent_jit(self): @jit @@ -992,7 +992,7 @@ class APITest(jtu.JaxTestCase): futures = [executor.submit(partial(f, x)) for x in xs] ys = [f.result() for f in futures] for x, y in zip(xs, ys): - self.assertAllClose(x * 2 - 3., y, check_dtypes=True) + self.assertAllClose(x * 2 - 3., y) def test_dtype_warning(self): # cf. issue #1230 @@ -1054,7 +1054,7 @@ class APITest(jtu.JaxTestCase): out1 = api.vmap(f, (None, 0, 0))(dictionary, x, y) out2 = api.vmap(f, [None, 0, 0])(dictionary, x, y) - self.assertAllClose(out1, out2, check_dtypes=True) + self.assertAllClose(out1, out2) def test_vmap_in_axes_tree_prefix_error(self): # https://github.com/google/jax/issues/795 @@ -2007,11 +2007,10 @@ class CustomJVPTest(jtu.JaxTestCase): f.defjvp(f_jvp) x = 3. 
- self.assertAllClose(f(x), jnp.sin(x), check_dtypes=True) + self.assertAllClose(f(x), jnp.sin(x)) self.assertAllClose(api.jvp(f, (x,), (1.,)), - (jnp.sin(x), 2 * jnp.cos(x)), - check_dtypes=True) - self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x), check_dtypes=True) + (jnp.sin(x), 2 * jnp.cos(x))) + self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x)) def test_invariance(self): @api.custom_jvp @@ -2052,8 +2051,8 @@ class CustomJVPTest(jtu.JaxTestCase): return f(x), 3 * g f.defjvp(f_jvp) x = 2. - self.assertAllClose(f(x), jnp.sin(x), check_dtypes=True) - self.assertAllClose(f(-x), jnp.cos(-x), check_dtypes=True) + self.assertAllClose(f(x), jnp.sin(x)) + self.assertAllClose(f(-x), jnp.cos(-x)) self.assertAllClose(api.jvp(f, (x,), (1.,)), (jnp.sin(x), 2.), check_dtypes=False) @@ -2079,29 +2078,24 @@ class CustomJVPTest(jtu.JaxTestCase): xx = jnp.arange(6.).reshape(2, 3) # vmap of f - self.assertAllClose(api.vmap(f)(x), jnp.sin(x), check_dtypes=True) - self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx), check_dtypes=True) + self.assertAllClose(api.vmap(f)(x), jnp.sin(x)) + self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx)) # vmap of jvp of f self.assertAllClose(api.vmap(lambda x: api.jvp(f, (x,), (x,)))(x), - (jnp.sin(x), 2 * jnp.cos(x) * x), - check_dtypes=True) + (jnp.sin(x), 2 * jnp.cos(x) * x)) self.assertAllClose(api.vmap(api.vmap(lambda x: api.jvp(f, (x,), (x,))))(xx), - (jnp.sin(xx), 2 * jnp.cos(xx) * xx), - check_dtypes=True) + (jnp.sin(xx), 2 * jnp.cos(xx) * xx)) # jvp of vmap of f self.assertAllClose(api.jvp(api.vmap(f), (x,), (x,)), - (jnp.sin(x), 2 * jnp.cos(x) * x), - check_dtypes=True) + (jnp.sin(x), 2 * jnp.cos(x) * x)) self.assertAllClose(api.jvp(api.vmap(api.vmap(f)), (xx,), (xx,)), - (jnp.sin(xx), 2 * jnp.cos(xx) * xx), - check_dtypes=True) + (jnp.sin(xx), 2 * jnp.cos(xx) * xx)) # vmap of jvp of vmap of f self.assertAllClose(api.vmap(lambda x: api.jvp(api.vmap(f), (x,), (x,)))(xx), - (jnp.sin(xx), 2 * jnp.cos(xx) * xx), - check_dtypes=True) + (jnp.sin(xx), 2 * jnp.cos(xx) * xx)) def test_jit(self): @api.custom_jvp @@ -2116,8 +2110,8 @@ class CustomJVPTest(jtu.JaxTestCase): x = 3. # jit - self.assertAllClose(api.jit(f)(x), jnp.sin(x), check_dtypes=True) - self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x), check_dtypes=True) + self.assertAllClose(api.jit(f)(x), jnp.sin(x)) + self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x)) # jit of jvp self.assertAllClose(api.jit(lambda x: api.jvp(f, (x,), (x,)))(x), @@ -2139,7 +2133,7 @@ class CustomJVPTest(jtu.JaxTestCase): return f(x), {'b': 2 * jnp.cos(x['a']) * g['a']} f.defjvp(f_jvp) x = {'a': 3.} - self.assertAllClose(f(x)['b'], jnp.sin(x['a']), check_dtypes=True) + self.assertAllClose(f(x)['b'], jnp.sin(x['a'])) self.assertAllClose(api.jvp(f, (x,), (x,)), ({'b': jnp.sin(x['a'])}, {'b': 2 * jnp.cos(x['a']) * x['a']}), @@ -2491,11 +2485,10 @@ class CustomVJPTest(jtu.JaxTestCase): f.defvjp(f_fwd, f_rev) x = 3. - self.assertAllClose(f(x), jnp.sin(x), check_dtypes=True) - self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x), check_dtypes=True) + self.assertAllClose(f(x), jnp.sin(x)) + self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x)) self.assertAllClose(api.value_and_grad(f)(x), - (jnp.sin(x), 2 * jnp.cos(x)), - check_dtypes=True) + (jnp.sin(x), 2 * jnp.cos(x))) def test_invariance(self): @api.custom_vjp @@ -2539,8 +2532,8 @@ class CustomVJPTest(jtu.JaxTestCase): return (3 * g,) f.defvjp(f_fwd, f_rev) x = 2. 
- self.assertAllClose(f(x), jnp.sin(x), check_dtypes=True) - self.assertAllClose(f(-x), jnp.cos(-x), check_dtypes=True) + self.assertAllClose(f(x), jnp.sin(x)) + self.assertAllClose(f(-x), jnp.cos(-x)) self.assertAllClose(api.value_and_grad(f)(x), (jnp.sin(x), 2.), check_dtypes=False) self.assertAllClose(api.value_and_grad(f)(-x), (jnp.cos(-x), 3.), @@ -2562,33 +2555,26 @@ class CustomVJPTest(jtu.JaxTestCase): xx = jnp.arange(6.).reshape(2, 3) # vmap of f - self.assertAllClose(api.vmap(f)(x), jnp.sin(x), check_dtypes=True) - self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx), check_dtypes=True) + self.assertAllClose(api.vmap(f)(x), jnp.sin(x)) + self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx)) # vmap of grad of f - self.assertAllClose(api.vmap(api.grad(f))(x), 2 * jnp.cos(x), - check_dtypes=True) + self.assertAllClose(api.vmap(api.grad(f))(x), 2 * jnp.cos(x)) self.assertAllClose(api.vmap(api.value_and_grad(f))(x), - (jnp.sin(x), 2 * jnp.cos(x)), - check_dtypes=True) - self.assertAllClose(api.vmap(api.vmap(api.grad(f)))(xx), 2 * jnp.cos(xx), - check_dtypes=True) + (jnp.sin(x), 2 * jnp.cos(x))) + self.assertAllClose(api.vmap(api.vmap(api.grad(f)))(xx), 2 * jnp.cos(xx)) self.assertAllClose(api.vmap(api.vmap(api.value_and_grad(f)))(xx), - (jnp.sin(xx), 2 * jnp.cos(xx)), - check_dtypes=True) + (jnp.sin(xx), 2 * jnp.cos(xx))) # grad of vmap of f self.assertAllClose(api.grad(lambda x: api.vmap(f)(x).sum())(x), - 2 * jnp.cos(x), - check_dtypes=True) + 2 * jnp.cos(x)) self.assertAllClose(api.grad(lambda x: api.vmap(api.vmap(f))(x).sum())(xx), - 2 * jnp.cos(xx), - check_dtypes=True) + 2 * jnp.cos(xx)) # vmap of grad of vmap of f self.assertAllClose(api.vmap(api.grad(lambda x: api.vmap(f)(x).sum()))(xx), - 2 * jnp.cos(xx), - check_dtypes=True) + 2 * jnp.cos(xx)) def test_jit(self): @api.custom_vjp @@ -2603,8 +2589,8 @@ class CustomVJPTest(jtu.JaxTestCase): x = 3. # jit - self.assertAllClose(api.jit(f)(x), jnp.sin(x), check_dtypes=True) - self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x), check_dtypes=True) + self.assertAllClose(api.jit(f)(x), jnp.sin(x)) + self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x)) # jit of grad self.assertAllClose(api.jit(api.grad(f))(x), 2 * jnp.cos(x), @@ -2625,10 +2611,9 @@ class CustomVJPTest(jtu.JaxTestCase): return ({'a': 2 * cos_x * g['b']},) f.defvjp(f_fwd, f_bwd) x = {'a': 3.} - self.assertAllClose(f(x)['b'], jnp.sin(x['a']), check_dtypes=True) + self.assertAllClose(f(x)['b'], jnp.sin(x['a'])) self.assertAllClose(api.grad(lambda x: f(x)['b'])(x), - {'a': 2 * jnp.cos(x['a'])}, - check_dtypes=True) + {'a': 2 * jnp.cos(x['a'])}) def test_jvp_error(self): @api.custom_vjp @@ -2680,7 +2665,7 @@ class CustomVJPTest(jtu.JaxTestCase): ans = api.grad(api.grad(foo))(3.) expected = -2. * jnp.sin(3.) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) def test_initial_style_vmap(self): @api.custom_vjp @@ -2851,7 +2836,7 @@ class DeprecatedCustomTransformsTest(jtu.JaxTestCase): ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (12.,))) val_ans, grad_ans = api.value_and_grad(foo)(3.) self.assertAllClose(val_ans, 9., check_dtypes=False) - self.assertAllClose(grad_ans, 12., check_dtypes=True) + self.assertAllClose(grad_ans, 12.) 
def test_defvjp_all_higher_order_revmode(self): foo_p = Primitive('foo') diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index ca0c218b3..6c61b9164 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -69,7 +69,7 @@ class DLPackTest(jtu.JaxTestCase): x = jnp.array(np) dlpack = jax.dlpack.to_dlpack(x) y = jax.dlpack.from_dlpack(dlpack) - self.assertAllClose(np.astype(x.dtype), y, check_dtypes=True) + self.assertAllClose(np.astype(x.dtype), y) self.assertRaisesRegex(RuntimeError, "DLPack tensor may be consumed at most once", @@ -89,7 +89,7 @@ class DLPackTest(jtu.JaxTestCase): x = x.cuda() if jtu.device_under_test() == "gpu" else x dlpack = torch.utils.dlpack.to_dlpack(x) y = jax.dlpack.from_dlpack(dlpack) - self.assertAllClose(np, y, check_dtypes=True) + self.assertAllClose(np, y) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}".format( @@ -104,7 +104,7 @@ class DLPackTest(jtu.JaxTestCase): x = jnp.array(np) dlpack = jax.dlpack.to_dlpack(x) y = torch.utils.dlpack.from_dlpack(dlpack) - self.assertAllClose(np, y.numpy(), check_dtypes=True) + self.assertAllClose(np, y.numpy()) class CudaArrayInterfaceTest(jtu.JaxTestCase): @@ -128,7 +128,7 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase): z = cupy.asarray(y) self.assertEqual(y.__cuda_array_interface__["data"][0], z.__cuda_array_interface__["data"][0]) - self.assertAllClose(x, cupy.asnumpy(z), check_dtypes=True) + self.assertAllClose(x, cupy.asnumpy(z)) if __name__ == "__main__": diff --git a/tests/batching_test.py b/tests/batching_test.py index 480ddc0e5..ae839ee70 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -195,7 +195,7 @@ class BatchingTest(jtu.JaxTestCase): ans = vmap(lambda x: x > 1.0)(x) expected_ans = x > 1.0 - self.assertAllClose(ans, expected_ans, check_dtypes=True) + self.assertAllClose(ans, expected_ans) def testNpMaximumPerExampleGrad(self): R = np.random.RandomState(0).randn @@ -226,35 +226,35 @@ class BatchingTest(jtu.JaxTestCase): fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))]) ans = vmap(fun)(x, y) expected = lax.dot_general(x, y, [((3,), (2,)), ((0, 1), (0, 1))]) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) x = R(3, 4, 10, 5) y = R(3, 10, 5, 6) fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))]) ans = vmap(fun, in_axes=(2, 1))(x, y) expected = np.stack([fun(x[..., i, :], y[:, i, ...]) for i in range(10)]) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) x = R(3, 4, 5, 10) y = R(3, 5, 6) fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))]) ans = vmap(fun, in_axes=(3, None))(x, y) expected = np.stack([fun(x[..., i], y) for i in range(10)]) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) x = R(3, 4, 5) y = R(3, 5, 10, 6) fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))]) ans = vmap(fun, in_axes=(None, 2))(x, y) expected = np.stack([fun(x, y[..., i, :]) for i in range(10)]) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) x = R(4) y = R(4, 10) fun = lambda x, y: lax.dot_general(x, y, [((0,), (0,)), ((), ())]) ans = vmap(fun, in_axes=(None, 1))(x, y) expected = np.stack([fun(x, y[..., i]) for i in range(10)]) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) def testDot(self): # these tests are based 
on @shoyer's notebook studying gufuncs @@ -348,7 +348,7 @@ class BatchingTest(jtu.JaxTestCase): ans = vmap(jnp.any)(jnp.array([[True, False], [False, False]])) expected = jnp.array([True, False]) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) @jtu.skip_on_devices("tpu") def testHessian(self): @@ -417,46 +417,44 @@ class BatchingTest(jtu.JaxTestCase): v = np.arange(12)[::-1].reshape(3, 4) sv = vmap(partial(lax.sort, dimension=0), (0,))(v) - self.assertAllClose(sv, v[:, ::-1], check_dtypes=True) + self.assertAllClose(sv, v[:, ::-1]) sv = vmap(partial(lax.sort, dimension=-1), (0,))(v) - self.assertAllClose(sv, v[:, ::-1], check_dtypes=True) + self.assertAllClose(sv, v[:, ::-1]) sv = vmap(partial(lax.sort, dimension=0), (1,))(v) - self.assertAllClose(sv, v[::-1, :].T, check_dtypes=True) + self.assertAllClose(sv, v[::-1, :].T) sv = vmap(partial(lax.sort, dimension=0), (1,), 1)(v) - self.assertAllClose(sv, v[::-1, :], check_dtypes=True) + self.assertAllClose(sv, v[::-1, :]) def testSortKeyVal(self): k = np.arange(12)[::-1].reshape(3, 4) v = np.random.RandomState(0).permutation(12).reshape(3, 4) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 0))(k, v) - self.assertAllClose(sk, k[:, ::-1], check_dtypes=True) - self.assertAllClose(sv, v[:, ::-1], check_dtypes=True) + self.assertAllClose(sk, k[:, ::-1]) + self.assertAllClose(sv, v[:, ::-1]) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 1), 1)(k, v) - self.assertAllClose(sk, k[::-1, :], check_dtypes=True) - self.assertAllClose(sv, v[::-1, :], check_dtypes=True) + self.assertAllClose(sk, k[::-1, :]) + self.assertAllClose(sv, v[::-1, :]) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 1))(k, v.T) - self.assertAllClose(sk, k[:, ::-1], check_dtypes=True) - self.assertAllClose(sv, v[:, ::-1], check_dtypes=True) + self.assertAllClose(sk, k[:, ::-1]) + self.assertAllClose(sv, v[:, ::-1]) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 0))(k.T, v) - self.assertAllClose(sk, k[:, ::-1], check_dtypes=True) - self.assertAllClose(sv, v[:, ::-1], check_dtypes=True) + self.assertAllClose(sk, k[:, ::-1]) + self.assertAllClose(sv, v[:, ::-1]) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (None, 0))(k[0], v) - self.assertAllClose(sk, np.broadcast_to(k[0, ::-1], (3, 4)), - check_dtypes=True) - self.assertAllClose(sv, v[:, ::-1], check_dtypes=True) + self.assertAllClose(sk, np.broadcast_to(k[0, ::-1], (3, 4))) + self.assertAllClose(sv, v[:, ::-1]) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, None))(k.T, v[0]) - self.assertAllClose(sk, k[:, ::-1], check_dtypes=True) - self.assertAllClose(sv, np.broadcast_to(v[0, ::-1], (3, 4)), - check_dtypes=True) + self.assertAllClose(sk, k[:, ::-1]) + self.assertAllClose(sv, np.broadcast_to(v[0, ::-1], (3, 4))) def testConvGeneralDilated(self): W = jnp.array(np.random.randn(3, 3, 1, 5), dtype=np.float32) @@ -474,7 +472,7 @@ class BatchingTest(jtu.JaxTestCase): per_example = vmap(partial(f, W))(jnp.reshape(X, (10, 1, 5, 5, 1))) per_example = jnp.reshape(per_example, (10, 5, 5, 5)) per_example_direct = f(W, X) - self.assertAllClose(per_example, per_example_direct, check_dtypes=True) + self.assertAllClose(per_example, per_example_direct) # Test gradients. 
per_example = vmap(partial(grad_loss, W))(jnp.reshape(X, (10, 1, 5, 5, 1))) @@ -484,7 +482,7 @@ class BatchingTest(jtu.JaxTestCase): per_example_direct += [ jnp.reshape(g, (1,) + g.shape)] per_example_direct = jnp.concatenate(per_example_direct, axis=0) - self.assertAllClose(per_example, per_example_direct, check_dtypes=True, + self.assertAllClose(per_example, per_example_direct, rtol=2e-2) def testConvGeneralDilatedBatchNotMajor(self): @@ -503,7 +501,7 @@ class BatchingTest(jtu.JaxTestCase): (5, 5, 21, 4)) per_example_direct = f(W, jnp.reshape(jnp.transpose(x, (1, 0, 2, 3, 4)), (5, 21, 5, 1))) - self.assertAllClose(per_example, per_example_direct, check_dtypes=True) + self.assertAllClose(per_example, per_example_direct) @parameterized.named_parameters( {"testcase_name": "_op={}".format(name), "op": op, "unit": unit} @@ -526,7 +524,7 @@ class BatchingTest(jtu.JaxTestCase): per_example = vmap(partial(f, W))(jnp.reshape(X, (10, 1, 5, 5, 1))) per_example = jnp.reshape(per_example, (10, 5, 5, 5)) per_example_direct = f(W, X) - self.assertAllClose(per_example, per_example_direct, check_dtypes=True) + self.assertAllClose(per_example, per_example_direct) # Test gradients. per_example = vmap(partial(grad_loss, W))(jnp.reshape(X, (10, 1, 5, 5, 1))) @@ -536,7 +534,7 @@ class BatchingTest(jtu.JaxTestCase): per_example_direct += [ jnp.reshape(g, (1,) + g.shape)] per_example_direct = jnp.concatenate(per_example_direct, axis=0) - self.assertAllClose(per_example, per_example_direct, check_dtypes=True, + self.assertAllClose(per_example, per_example_direct, rtol=5e-2) def testSumPool(self): @@ -557,7 +555,7 @@ class BatchingTest(jtu.JaxTestCase): per_example = vmap(partial(f, W))(jnp.reshape(X, (10, 1, 5, 5, 1))) per_example = jnp.reshape(per_example, (10, 5, 5, 5)) per_example_direct = f(W, X) - self.assertAllClose(per_example, per_example_direct, check_dtypes=True) + self.assertAllClose(per_example, per_example_direct) # Test gradients. 
per_example = vmap(partial(grad_loss, W))(jnp.reshape(X, (10, 1, 5, 5, 1))) @@ -567,14 +565,13 @@ class BatchingTest(jtu.JaxTestCase): per_example_direct += [ jnp.reshape(g, (1,) + g.shape)] per_example_direct = jnp.concatenate(per_example_direct, axis=0) - self.assertAllClose(per_example, per_example_direct, check_dtypes=True, + self.assertAllClose(per_example, per_example_direct, rtol=3e-2) def testCumProd(self): x = jnp.arange(9).reshape(3, 3) + 1 y = vmap(lambda x: jnp.cumprod(x, axis=-1))(x) - self.assertAllClose(np.cumprod(x, axis=1, dtype=jnp.int_), y, - check_dtypes=True) + self.assertAllClose(np.cumprod(x, axis=1, dtype=jnp.int_), y) def testSelect(self): pred = np.array([True, False]) @@ -582,7 +579,7 @@ class BatchingTest(jtu.JaxTestCase): on_false = np.array([2, 3]) ans = vmap(lax.select)(pred, on_true, on_false) expected = np.array([0, 3]) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) pred = np.array([False, True]) on_true = np.array([0, 1]) @@ -590,28 +587,28 @@ class BatchingTest(jtu.JaxTestCase): ans = vmap(lax.select, (0, None, None))(pred, on_true, on_false) expected = np.array([[2, 3], [0, 1]]) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) pred = True on_true = np.array([0, 1], np.float32) on_false = np.array(3, np.float32) ans = vmap(lax.select, (None, 0, None))(pred, on_true, on_false) expected = np.array([0, 1], np.float32) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) pred = np.array([False, True]) on_true = np.array([0, 1], np.float32) on_false = np.array(3, np.float32) ans = vmap(lax.select, (0, 0, None))(pred, on_true, on_false) expected = np.array([3, 1], np.float32) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) pred = np.array([False, True]) on_true = np.array([2], np.float32) on_false = np.array([[3, 4]], np.float32) ans = vmap(lax.select, (0, None, 1), 1)(pred, on_true, on_false) expected = np.array([[3, 2]], np.float32) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) def testLaxLinalgCholesky(self): a = np.random.RandomState(0).randn(10, 5, 5).astype(np.float32) @@ -637,17 +634,17 @@ class BatchingTest(jtu.JaxTestCase): ans = vmap(lax_linalg.triangular_solve, in_axes=(1, 2))(a, b) expected = np.stack( [lax_linalg.triangular_solve(a[:, i], b[..., i]) for i in range(10)]) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) ans = vmap(lax_linalg.triangular_solve, in_axes=(None, 2))(a[:, 0], b) expected = np.stack( [lax_linalg.triangular_solve(a[:, 0], b[..., i]) for i in range(10)]) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) ans = vmap(lax_linalg.triangular_solve, in_axes=(1, None))(a, b[..., 0]) expected = np.stack( [lax_linalg.triangular_solve(a[:, i], b[..., 0]) for i in range(10)]) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) @parameterized.named_parameters( {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format( @@ -987,7 +984,7 @@ class BatchingTest(jtu.JaxTestCase): g = jax.jit(jax.pmap(f)) ans = g(index1=np.asarray([1]), index2=np.asarray([2])) expected = g(np.asarray([1]), np.asarray([2])) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) if __name__ == '__main__': diff --git a/tests/callback_test.py b/tests/callback_test.py 
index aaa8d6025..2254d3e34 100644 --- a/tests/callback_test.py +++ b/tests/callback_test.py @@ -69,11 +69,11 @@ class CallbackTest(jtu.JaxTestCase): return x * 2 x = jnp.array([2.0, 4.0]) - self.assertAllClose(f(x), jnp.array([4.0, 8.0]), True) + self.assertAllClose(f(x), jnp.array([4.0, 8.0])) self.assertAllClose( rewrite(f, {lax.mul_p: lambda x, y: x + y})(x), - jnp.array([4.0, 6.0]), True) + jnp.array([4.0, 6.0])) def testRewriteJIT(self): def f(x): @@ -83,11 +83,11 @@ class CallbackTest(jtu.JaxTestCase): return g(x) x = jnp.array([2.0, 4.0]) - self.assertAllClose(f(x), jnp.array([4.0, 8.0]), True) + self.assertAllClose(f(x), jnp.array([4.0, 8.0])) self.assertAllClose( rewrite(f, {lax.mul_p: lambda x, y: x + y})(x), - jnp.array([4.0, 6.0]), True) + jnp.array([4.0, 6.0])) if __name__ == "__main__": absltest.main() diff --git a/tests/fft_test.py b/tests/fft_test.py index 2c702c6b7..7500af891 100644 --- a/tests/fft_test.py +++ b/tests/fft_test.py @@ -109,7 +109,7 @@ class FftTest(jtu.JaxTestCase): # Numpy promotes to complex128 aggressively. self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False, tol=1e-4) - self._CompileAndCheck(jnp_fn, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fn, args_maker) # Test gradient for differentiable types. if dtype in (float_dtypes if real and not inverse else inexact_dtypes): # TODO(skye): can we be more precise? @@ -169,7 +169,7 @@ class FftTest(jtu.JaxTestCase): # Numpy promotes to complex128 aggressively. self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False, tol=1e-4) - self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_op, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inverse={}_real={}".format(inverse, real), @@ -228,7 +228,7 @@ class FftTest(jtu.JaxTestCase): # Numpy promotes to complex128 aggressively. self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=False, tol=1e-4) - self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_op, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inverse={}_real={}".format(inverse, real), @@ -279,7 +279,7 @@ class FftTest(jtu.JaxTestCase): # Numpy promotes to complex128 aggressively. self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False, tol=1e-4) - self._CompileAndCheck(jnp_fn, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fn, args_maker) # Test gradient for differentiable types. if dtype in inexact_dtypes: tol = 0.15 # TODO(skye): can we be more precise? @@ -323,7 +323,7 @@ class FftTest(jtu.JaxTestCase): # Numpy promotes to complex128 aggressively. self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False, tol=1e-4) - self._CompileAndCheck(jnp_fn, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fn, args_maker) # Test gradient for differentiable types. if dtype in inexact_dtypes: tol = 0.15 # TODO(skye): can we be more precise? 
@@ -362,7 +362,7 @@ class FftTest(jtu.JaxTestCase): args_maker = lambda: (rng(shape, dtype),) jnp_fn = lambda arg: jnp.fft.fftshift(arg, axes=axes) np_fn = lambda arg: np.fft.fftshift(arg, axes=axes) - self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "dtype={}_axes={}".format( @@ -377,7 +377,7 @@ class FftTest(jtu.JaxTestCase): args_maker = lambda: (rng(shape, dtype),) jnp_fn = lambda arg: jnp.fft.ifftshift(arg, axes=axes) np_fn = lambda arg: np.fft.ifftshift(arg, axes=axes) - self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker) if __name__ == "__main__": absltest.main() diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index ed0e8b200..d0f5361d1 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -148,7 +148,7 @@ class HostCallbackTest(jtu.JaxTestCase): self.assertEqual("", testing_stream.output) with hcb.outfeed_receiver(): - self.assertAllClose((5. * 2.) ** 2, fun1(5.), check_dtypes=True) + self.assertAllClose((5. * 2.) ** 2, fun1(5.)) assertMultiLineStrippedEqual(self, """ what: a * 2 10.00 @@ -232,7 +232,7 @@ what: x3 with hcb.outfeed_receiver(receiver_name=self._testMethodName): res = jit_fun1(5.) - self.assertAllClose(6. * 5., res, check_dtypes=True) + self.assertAllClose(6. * 5., res) assertMultiLineStrippedEqual(self, """ what: here 10.00""", testing_stream.output) @@ -258,7 +258,7 @@ what: here self.assertEqual("", testing_stream.output) with hcb.outfeed_receiver(): - self.assertAllClose(5, api.jit(func)(5), check_dtypes=True) + self.assertAllClose(5, api.jit(func)(5)) assertMultiLineStrippedEqual(self, """ 42""", testing_stream.output) testing_stream.reset() @@ -512,7 +512,7 @@ where: 3 if with_jit: func = api.jit(func) res = func(1) - self.assertAllClose(jnp.array([1, 2, 3]), res, check_dtypes=True) + self.assertAllClose(jnp.array([1, 2, 3]), res) assertMultiLineStrippedEqual(self, """ where: 1 1 @@ -564,7 +564,7 @@ where: 10 testcase_name=f"shape_{shape}_dtype_{dtype}_nr_args={nr_args}")) with hcb.outfeed_receiver(receiver_name=self._testMethodName): res = jit_fun1(args) - # self.assertAllClose(args, res, check_dtypes=True) + # self.assertAllClose(args, res) def test_jit_large(self): arg = jnp.arange(10000, dtype=jnp.int32).reshape((10, 10, 5, -1)) diff --git a/tests/infeed_test.py b/tests/infeed_test.py index 9555af666..746d5f0d5 100644 --- a/tests/infeed_test.py +++ b/tests/infeed_test.py @@ -44,7 +44,7 @@ class InfeedTest(jax.test_util.JaxTestCase): device = jax.local_devices()[0] device.transfer_to_infeed((y,)) device.transfer_to_infeed((z,)) - self.assertAllClose(f(x), x + y + z, check_dtypes=True) + self.assertAllClose(f(x), x + y + z) def testInfeedThenOutfeed(self): @jax.jit @@ -64,7 +64,7 @@ class InfeedTest(jax.test_util.JaxTestCase): out, = device.transfer_from_outfeed( xla_client.shape_from_pyval((y,)).with_major_to_minor_layout_if_absent()) execution.join() - self.assertAllClose(out, y + onp.float32(1), check_dtypes=True) + self.assertAllClose(out, y + onp.float32(1)) def testInfeedThenOutfeedInALoop(self): def doubler(_, token): @@ -87,7 +87,7 @@ class InfeedTest(jax.test_util.JaxTestCase): device.transfer_to_infeed((x,)) y, = device.transfer_from_outfeed(xla_client.shape_from_pyval((x,)) .with_major_to_minor_layout_if_absent()) - self.assertAllClose(y, x * onp.float32(2), check_dtypes=True) + 
self.assertAllClose(y, x * onp.float32(2)) execution.join() diff --git a/tests/jet_test.py b/tests/jet_test.py index 84f917daf..4be94576d 100644 --- a/tests/jet_test.py +++ b/tests/jet_test.py @@ -159,11 +159,9 @@ class JetTest(jtu.JaxTestCase): atol = 1e-4 rtol = 1e-4 - self.assertAllClose(y, expected_y, atol=atol, rtol=rtol, - check_dtypes=True) + self.assertAllClose(y, expected_y, atol=atol, rtol=rtol) - self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol, - check_dtypes=True) + self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol) @jtu.skip_on_devices("tpu") diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index d3ae78dc1..ef3a76b7f 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -404,7 +404,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): self.assertEqual(count(3), 3) self.assertEqual(count(4), 6) for args_maker in [lambda: [2], lambda: [3], lambda: [4]]: - self._CompileAndCheck(count, args_maker, True) + self._CompileAndCheck(count, args_maker) def testForiLoopClosure(self): def count(num): @@ -1356,13 +1356,13 @@ class LaxControlFlowTest(jtu.JaxTestCase): xs = jnp.arange(10) expected = xs ** 2 actual = lax.map(f, xs) - self.assertAllClose(actual, expected, check_dtypes=True) + self.assertAllClose(actual, expected) def testMapEmpty(self): # https://github.com/google/jax/issues/2412 ans = lax.map(lambda x: x * x, jnp.array([])) expected = jnp.array([]) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) def testCaching(self): def cond(x): @@ -1598,7 +1598,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): actual = api.jit(linear_solve)(a, b) expected = jnp.linalg.solve(a, b) - self.assertAllClose(expected, actual, check_dtypes=True) + self.assertAllClose(expected, actual) def test_custom_root_with_custom_linear_solve(self): @@ -1620,11 +1620,11 @@ class LaxControlFlowTest(jtu.JaxTestCase): actual = linear_solve(jnp.dot(a, a.T), b) expected = jnp.linalg.solve(jnp.dot(a, a.T), b) - self.assertAllClose(expected, actual, check_dtypes=True) + self.assertAllClose(expected, actual) actual = api.jit(linear_solve)(jnp.dot(a, a.T), b) expected = jnp.linalg.solve(jnp.dot(a, a.T), b) - self.assertAllClose(expected, actual, check_dtypes=True) + self.assertAllClose(expected, actual) jtu.check_grads(lambda x, y: linear_solve(jnp.dot(x, x.T), y), (a, b), order=2, rtol={jnp.float32: 1e-2}) @@ -1670,12 +1670,12 @@ class LaxControlFlowTest(jtu.JaxTestCase): expected = jnp.linalg.solve(a, b) actual = api.jit(linear_solve)(a, b) - self.assertAllClose(expected, actual, check_dtypes=True) + self.assertAllClose(expected, actual) c = rng.randn(3, 2) expected = jnp.linalg.solve(a, c) actual = api.vmap(linear_solve, (None, 1), 1)(a, c) - self.assertAllClose(expected, actual, check_dtypes=True) + self.assertAllClose(expected, actual) @jtu.skip_on_flag("jax_skip_slow_tests", True) def test_custom_linear_solve_zeros(self): @@ -1723,7 +1723,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): b = rng.randn(2) expected = jnp.linalg.solve(jnp.exp(a), jnp.cos(b)) actual = build_and_solve(a, b) - self.assertAllClose(expected, actual, atol=1e-5, check_dtypes=True) + self.assertAllClose(expected, actual, atol=1e-5) jtu.check_grads(build_and_solve, (a, b), atol=1e-5, order=2, rtol={jnp.float32: 6e-2, jnp.float64: 2e-3}) @@ -1749,10 +1749,10 @@ class LaxControlFlowTest(jtu.JaxTestCase): expected = jnp.linalg.solve(np.asarray(posify(a)), b) actual = positive_definite_solve(posify(a), b) - 
self.assertAllClose(expected, actual, check_dtypes=True) + self.assertAllClose(expected, actual) actual = api.jit(positive_definite_solve)(posify(a), b) - self.assertAllClose(expected, actual, check_dtypes=True) + self.assertAllClose(expected, actual) # numerical gradients are only well defined if ``a`` is guaranteed to be # positive definite. @@ -1798,7 +1798,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): expected = jnp.linalg.solve(a, b) actual = linear_solve(a, b) - self.assertAllClose(expected, actual, check_dtypes=True) + self.assertAllClose(expected, actual) jtu.check_grads(linear_solve, (a, b), order=2, rtol=2e-3) diff --git a/tests/lax_numpy_einsum_test.py b/tests/lax_numpy_einsum_test.py index cec0fe352..f28f5cd64 100644 --- a/tests/lax_numpy_einsum_test.py +++ b/tests/lax_numpy_einsum_test.py @@ -33,7 +33,7 @@ class EinsumTest(jtu.JaxTestCase): def _check(self, s, *ops): a = np.einsum(s, *ops) b = jnp.einsum(s, *ops, precision=lax.Precision.HIGHEST) - self.assertAllClose(a, b, atol=1e-4, rtol=1e-4, check_dtypes=True) + self.assertAllClose(a, b, atol=1e-4, rtol=1e-4) def test_three_operands_1(self): r = self.rng() diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index afced4634..1f9b9bf0f 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -402,7 +402,7 @@ class IndexingTest(jtu.JaxTestCase): rng = rng_factory(self.rng()) args_maker = lambda: [rng(shape, dtype)] fun = lambda x: x[indexer] - self._CompileAndCheck(fun, args_maker, check_dtypes=True) + self._CompileAndCheck(fun, args_maker) @parameterized.named_parameters({ "testcase_name": @@ -503,7 +503,7 @@ class IndexingTest(jtu.JaxTestCase): return x[indexer] args_maker = lambda: [rng(shape, dtype), unpacked_indexer] - self._CompileAndCheck(fun, args_maker, check_dtypes=True) + self._CompileAndCheck(fun, args_maker) @parameterized.named_parameters( {"testcase_name": "{}_inshape={}_indexer={}" @@ -553,7 +553,7 @@ class IndexingTest(jtu.JaxTestCase): rng = rng_factory(self.rng()) args_maker = lambda: [rng(shape, dtype), indexer] fun = lambda x, idx: jnp.asarray(x)[idx] - self._CompileAndCheck(fun, args_maker, check_dtypes=True) + self._CompileAndCheck(fun, args_maker) @parameterized.named_parameters( {"testcase_name": "{}_inshape={}_indexer={}" @@ -635,7 +635,7 @@ class IndexingTest(jtu.JaxTestCase): idx = type(indexer)(util.subvals(indexer_with_dummies, substitutes)) return jnp.asarray(x)[idx] - self._CompileAndCheck(fun, args_maker, check_dtypes=True) + self._CompileAndCheck(fun, args_maker) def testAdvancedIndexingManually(self): x = onp.random.RandomState(0).randn(3, 4, 5) @@ -647,7 +647,7 @@ class IndexingTest(jtu.JaxTestCase): a1 = op(x, index_array) a2 = cop(x, index_array) - self.assertAllClose(a1, a2, check_dtypes=True) + self.assertAllClose(a1, a2) op = lambda x, index_array: x[..., index_array, :, index_array, None] cop = api.jit(op) @@ -655,7 +655,7 @@ class IndexingTest(jtu.JaxTestCase): a1 = op(x, index_array) a2 = cop(x, index_array) - self.assertAllClose(a1, a2, check_dtypes=True) + self.assertAllClose(a1, a2) op = lambda x, index_array: x[index_array, ..., index_array[:, None], None] cop = api.jit(op) @@ -663,7 +663,7 @@ class IndexingTest(jtu.JaxTestCase): a1 = op(x, index_array) a2 = cop(x, index_array) - self.assertAllClose(a1, a2, check_dtypes=True) + self.assertAllClose(a1, a2) def testUnpacking(self): @@ -676,7 +676,7 @@ class IndexingTest(jtu.JaxTestCase): a1 = foo(onp.arange(3)) a2 = cfoo(onp.arange(3)) - self.assertAllClose(a1, a2, 
check_dtypes=True) + self.assertAllClose(a1, a2) def testBooleanIndexingArray1D(self): idx = onp.array([True, True, False]) @@ -739,8 +739,8 @@ class IndexingTest(jtu.JaxTestCase): primals, tangents = api.jvp(api.grad(f), (x, i), (x, onp.zeros_like(i))) expected = onp.broadcast_to( onp.array([0, 3, 0], dtype=onp.float32)[:, None], (3, 4)) - self.assertAllClose(expected, primals, check_dtypes=True) - self.assertAllClose(onp.zeros_like(x), tangents, check_dtypes=True) + self.assertAllClose(expected, primals) + self.assertAllClose(onp.zeros_like(x), tangents) def testTrivialGatherIsntGenerated(self): # https://github.com/google/jax/issues/1621 @@ -791,7 +791,7 @@ class IndexingTest(jtu.JaxTestCase): def testIndexOutOfBounds(self): # https://github.com/google/jax/issues/2245 array = jnp.ones(5) - self.assertAllClose(array, array[:10], check_dtypes=True) + self.assertAllClose(array, array[:10]) def _broadcastable_shapes(shape): @@ -877,8 +877,8 @@ class IndexedUpdateTest(jtu.JaxTestCase): jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y) else: jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y) - self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True) - self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker) + self._CompileAndCheck(jax_fn, args_maker) @parameterized.named_parameters(jtu.cases_from_list({ "testcase_name": "{}_inshape={}_indexer={}_update={}_sugared={}_op={}".format( @@ -904,8 +904,8 @@ class IndexedUpdateTest(jtu.JaxTestCase): jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y) else: jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y) - self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True) - self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker) + self._CompileAndCheck(jax_fn, args_maker) @parameterized.named_parameters(jtu.cases_from_list({ "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format( @@ -931,8 +931,8 @@ class IndexedUpdateTest(jtu.JaxTestCase): jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y) else: jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y) - self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True) - self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker) + self._CompileAndCheck(jax_fn, args_maker) @parameterized.named_parameters(jtu.cases_from_list({ "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format( diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 00637fc76..530799920 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -91,7 +91,8 @@ OpRecord = collections.namedtuple( "test_name", "check_dtypes", "tolerance", "inexact"]) def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, - test_name=None, check_dtypes=True, tolerance=None, inexact=False): + test_name=None, check_dtypes=True, + tolerance=None, inexact=False): test_name = test_name or name return OpRecord(name, nargs, dtypes, shapes, rng_factory, diff_modes, test_name, check_dtypes, tolerance, inexact) @@ -497,7 +498,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): () in shapes) empty_shape = any(isinstance(s, tuple) and 0 in s for s in shapes) self._CompileAndCheck( - fun, args_maker, check_dtypes=True, #not scalar_arg and not empty_shape, + fun, args_maker, #not scalar_arg and not empty_shape, atol=tol, rtol=tol) 
@parameterized.named_parameters(itertools.chain.from_iterable( @@ -525,7 +526,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): () in shapes) empty_shape = any(isinstance(s, tuple) and 0 in s for s in shapes) self._CompileAndCheck( - fun, args_maker, check_dtypes=True, # not scalar_arg and not empty_shape, + fun, args_maker, # not scalar_arg and not empty_shape, atol=tol, rtol=tol) @parameterized.named_parameters(jtu.cases_from_list( @@ -577,7 +578,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = self._GetArgsMaker(rng, shapes, dtypes) self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=jtu.PYTHON_SCALAR_SHAPE not in shapes) - self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_op, args_maker) @parameterized.named_parameters(itertools.chain.from_iterable( jtu.cases_from_list( @@ -616,7 +617,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=jnp.bfloat16 not in (dtype, out_dtype), tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, atol=tol, + self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol) @parameterized.named_parameters(jtu.cases_from_list( @@ -646,8 +647,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims) jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_axis={}".format( @@ -661,7 +662,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): jnp_fun = lambda x: jnp.count_nonzero(x, axis) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}".format( @@ -733,13 +734,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = lambda: [rng(shape, dtype)] try: - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) except ValueError as e: if str(e) == "All-NaN slice encountered": self.skipTest("JAX doesn't support checking for all-NaN slices") else: raise - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": rec.test_name.capitalize(), "name": rec.name, @@ -755,8 +756,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np_fun = partial(np_op, axis=0) jnp_fun = partial(jnp_op, axis=0) args_maker = lambda: [np.zeros((2, 0))] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_{}_{}".format( @@ -793,9 +794,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): tol_spec = {dtypes.bfloat16: 3e-1, np.float16: 0.15} tol = max(jtu.tolerance(lhs_dtype, tol_spec), jtu.tolerance(rhs_dtype, tol_spec)) - self._CheckAgainstNumpy(np_fun, 
jnp_fun, args_maker, check_dtypes=True, + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, atol=tol, + self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol) @parameterized.named_parameters(jtu.cases_from_list( @@ -830,9 +831,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): x = x.astype(np.float32) if lhs_dtype == jnp.bfloat16 else x y = y.astype(np.float32) if rhs_dtype == jnp.bfloat16 else y return np.dot(x, y).astype(jnp.promote_types(lhs_dtype, rhs_dtype)) - self._CheckAgainstNumpy(np_dot, jnp.dot, args_maker, check_dtypes=True, + self._CheckAgainstNumpy(np_dot, jnp.dot, args_maker, tol=tol) - self._CompileAndCheck(jnp.dot, args_maker, check_dtypes=True, atol=tol, + self._CompileAndCheck(jnp.dot, args_maker, atol=tol, rtol=tol) @parameterized.named_parameters(jtu.cases_from_list( @@ -866,10 +867,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np.complex128: 1e-12} if jtu.device_under_test() == "tpu": tol[np.float32] = tol[np.complex64] = 4e-2 - self._CheckAgainstNumpy(np_fun, jnp.matmul, args_maker, - check_dtypes=True, tol=tol) - self._CompileAndCheck(jnp.matmul, args_maker, check_dtypes=True, atol=tol, - rtol=tol) + self._CheckAgainstNumpy(np_fun, jnp.matmul, args_maker, tol=tol) + self._CompileAndCheck(jnp.matmul, args_maker, atol=tol, rtol=tol) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_{}_{}".format( @@ -902,9 +901,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np.complex64: 1e-3, np.complex128: 1e-12} if jtu.device_under_test() == "tpu": tol[np.float32] = tol[np.complex64] = 2e-1 - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fun, args_maker) def testTensordotErrors(self): a = np.random.random((3, 2, 2)) @@ -941,8 +940,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = lambda: [rng(element_shape, dtype), rng(test_shape, dtype)] jnp_fun = lambda e, t: jnp.isin(e, t, invert=invert) np_fun = lambda e, t: np.isin(e, t, invert=invert) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( @@ -960,8 +959,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = lambda: [rng(element_shape, dtype), rng(test_shape, dtype)] jnp_fun = lambda e, t: jnp.in1d(e, t, invert=invert) np_fun = lambda e, t: np.in1d(e, t, invert=invert) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( @@ -1015,7 +1014,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = lambda: [rng(shape, dtype)] # TODO(phawkins): the promotion behavior changed in Numpy 1.17. 
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fun, args_maker) def testClipError(self): with self.assertRaisesRegex(ValueError, "At most one of a_min and a_max.*"): @@ -1044,15 +1043,15 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): def testOperatorRound(self): self.assertAllClose(round(np.float32(7.532), 1), - round(jnp.float32(7.5), 1), check_dtypes=True) + round(jnp.float32(7.5), 1)) self.assertAllClose(round(np.float32(1.234), 2), - round(jnp.float32(1.234), 2), check_dtypes=True) + round(jnp.float32(1.234), 2)) self.assertAllClose(round(np.float32(1.234)), round(jnp.float32(1.234)), check_dtypes=False) self.assertAllClose(round(np.float32(7.532), 1), - round(jnp.array(7.5, jnp.float32), 1), check_dtypes=True) + round(jnp.array(7.5, jnp.float32), 1)) self.assertAllClose(round(np.float32(1.234), 2), - round(jnp.array(1.234, jnp.float32), 2), check_dtypes=True) + round(jnp.array(1.234, jnp.float32), 2)) self.assertAllClose(round(np.float32(1.234)), round(jnp.array(1.234, jnp.float32)), check_dtypes=False) @@ -1098,7 +1097,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape=[{}]_reps={}".format( @@ -1117,7 +1116,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}".format( @@ -1128,7 +1127,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): def testExtract(self, shape, dtype): rng = jtu.rand_some_zero(self.rng()) args_maker = lambda: [rng(shape, jnp.float32), rng(shape, dtype)] - self._CheckAgainstNumpy(np.extract, jnp.extract, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np.extract, jnp.extract, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_axis={}".format( @@ -1150,7 +1149,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np_fun = partial(np.compress, axis=axis) jnp_fun = partial(jnp.compress, axis=axis) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_condition=array[{}]_axis={}".format( @@ -1166,7 +1165,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = lambda: [np.array(condition), rng(shape, dtype)] np_fun = partial(np.compress, axis=axis) jnp_fun = partial(jnp.compress, axis=axis) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format( @@ -1193,8 +1192,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): def args_maker(): return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + 
self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format( @@ -1221,8 +1220,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): def args_maker(): return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape=[{}]_axis={}_repeats={}".format( @@ -1240,8 +1239,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_ind={}_inv={}_count={}".format( @@ -1260,7 +1259,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = lambda: [rng(shape, dtype)] np_fun = lambda x: np.unique(x, return_index, return_inverse, return_counts) jnp_fun = lambda x: jnp.unique(x, return_index, return_inverse, return_counts) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) def testIssue1233(self): ''' @@ -1273,10 +1272,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): lax_ans = jnp.repeat(m, repeats, axis) numpy_ans = np.repeat(m, repeats, axis) - self.assertAllClose(lax_ans, numpy_ans, check_dtypes=True, rtol=tol, atol=tol) + self.assertAllClose(lax_ans, numpy_ans, rtol=tol, atol=tol) jnp_fun = lambda arg: jnp.repeat(arg, repeats = repeats, axis=axis) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fun, args_maker) m = jnp.array([1,2,3,4,5,6]) args_maker = lambda: [m] @@ -1313,8 +1312,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): attempt_sideeffect(np_input) attempt_sideeffect(jnp_input) - self.assertAllClose(np_input, expected_np_input_after_call, check_dtypes=True) - self.assertAllClose(jnp_input, expected_jnp_input_after_call, check_dtypes=True) + self.assertAllClose(np_input, expected_np_input_after_call) + self.assertAllClose(jnp_input, expected_jnp_input_after_call) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "op={}_xshape=[{}]_yshape=[{}]_mode={}".format( @@ -1339,7 +1338,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): jnp_fun = partial(jnp_op, mode=mode, precision=precision) tol = {np.float16: 2e-1, np.float32: 1e-2, np.float64: 1e-14} self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "op={}_shape=[{}]_axis={}_out_dtype={}".format( @@ -1365,9 +1364,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): tol_thresholds = {dtypes.bfloat16: 4e-2} tol = max(jtu.tolerance(dtype, tol_thresholds), jtu.tolerance(out_dtype, tol_thresholds)) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fun, args_maker) 
@parameterized.named_parameters(jtu.cases_from_list( @@ -1411,8 +1410,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np_fun = lambda: np.tri(n, M=m, k=k, dtype=dtype) jnp_fun = lambda: jnp.tri(n, M=m, k=k, dtype=dtype) args_maker = lambda: [] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_op={}_shape={}_k={}".format( @@ -1428,8 +1427,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np_fun = lambda arg: getattr(np, op)(arg, k=k) jnp_fun = lambda arg: getattr(jnp, op)(arg, k=k) args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_ndim={}_n={}".format(ndim, n), @@ -1453,8 +1452,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np_fun = lambda arg: np.diag(arg, k) jnp_fun = lambda arg: jnp.diag(arg, k) args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_offset={}_axis1={}_axis2={}".format( @@ -1472,8 +1471,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np_fun = lambda arg: np.diagonal(arg, offset, axis1, axis2) jnp_fun = lambda arg: jnp.diagonal(arg, offset, axis1, axis2) args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_n={}".format(np.dtype(dtype).name, n), @@ -1484,8 +1483,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np_fun = lambda: np.identity(n, dtype) jnp_fun = lambda: jnp.identity(n, dtype) args_maker = lambda: [] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_x1={}_x2={}_x1_rng={}".format( @@ -1515,8 +1514,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): jnp_fun = lambda x1, x2: jnp.ldexp(x1, x2) args_maker = lambda: [x1_rng(x1_shape, x1_dtype), x2_rng(x2_shape, np.int32)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_x={}_rng_factory={}".format( @@ -1539,8 +1538,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np_fun = lambda x: np.frexp(x) jnp_fun = lambda x: jnp.frexp(x) args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - 
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( @@ -1565,8 +1564,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): return np.trace(arg, offset, axis1, axis2, out_dtype) jnp_fun = lambda arg: jnp.trace(arg, offset, axis1, axis2, out_dtype) args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_a={}_v={}_side={}".format( @@ -1585,8 +1584,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = lambda: [jnp.sort(rng(ashape, dtype)), rng(vshape, dtype)] np_fun = lambda a, v: np.searchsorted(a, v, side=side) jnp_fun = lambda a, v: jnp.searchsorted(a, v, side=side) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_x={}_bins={}_right={}_reverse={}".format( @@ -1607,8 +1606,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = lambda: [rng(xshape, dtype), jnp.sort(rng(binshape, dtype))[order]] np_fun = lambda x, bins: np.digitize(x, bins, right=right) jnp_fun = lambda x, bins: jnp.digitize(x, bins, right=right) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_axis={}".format( @@ -1629,7 +1628,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] np_fun = _promote_like_jnp(partial(np.stack, axis=axis)) jnp_fun = partial(jnp.stack, axis=axis) - self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( @@ -1651,7 +1650,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] np_fun = _promote_like_jnp(getattr(np, op)) jnp_fun = getattr(jnp, op) - self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_outdtype={}".format( @@ -1667,8 +1666,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np_fun = lambda fill_value: np.full(shape, fill_value, dtype=out_dtype) jnp_fun = lambda fill_value: jnp.full(shape, fill_value, dtype=out_dtype) args_maker = lambda: [rng((), fill_value_dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters( jtu.cases_from_list( @@ -1684,8 +1683,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): def args_maker(): return [] np_op = partial(np_op, shape, dtype) jnp_op = partial(jnp_op, shape, dtype) - 
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) def testOnesWithInvalidShape(self): with self.assertRaises(TypeError): @@ -1708,8 +1707,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np_fun = lambda x, fill_value: np.full_like(x, fill_value, dtype=out_dtype) jnp_fun = lambda x, fill_value: jnp.full_like(x, fill_value, dtype=out_dtype) args_maker = lambda: [rng(shape, in_dtype), rng((), fill_value_dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_axis={}_{}sections".format( @@ -1725,8 +1724,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np_fun = lambda x: np.split(x, num_sections, axis=axis) jnp_fun = lambda x: jnp.split(x, num_sections, axis=axis) args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) def testSplitTypeError(self): # If we pass an ndarray for indices_or_sections -> no error @@ -1774,7 +1773,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): # linspace() compares poorly to numpy when using bfloat16 if dtype != jnp.bfloat16: self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, + self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol) @parameterized.named_parameters(jtu.cases_from_list( @@ -1808,7 +1807,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): if dtype != jnp.bfloat16: self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_axis={}_{}sections".format( @@ -1833,8 +1832,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np_fun = lambda x: fn(np, axis)(x, num_sections) jnp_fun = lambda x: fn(jnp, axis)(x, num_sections) args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_outshape={}_order={}".format( @@ -1860,8 +1859,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np_fun = lambda x: np.reshape(x, out_shape, order=order) jnp_fun = lambda x: jnp.reshape(x, out_shape, order=order) args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_outshape={}".format( @@ -1880,8 +1879,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np_fun = lambda x: np.reshape(x, out_shape) jnp_fun = lambda x: 
x.reshape(*out_shape) args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_expanddim={!r}".format( @@ -1898,11 +1897,11 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np_fun = lambda x: np.expand_dims(x, dim) jnp_fun = lambda x: jnp.expand_dims(x, dim) args_maker = lambda: [rng(arg_shape, dtype)] - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fun, args_maker) if isinstance(dim, tuple) and numpy_version < (1, 18, 0): raise SkipTest("support for multiple axes added in NumPy 1.18.0") - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_axes=({},{})".format( @@ -1918,8 +1917,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np_fun = lambda x: np.swapaxes(x, ax1, ax2) jnp_fun = lambda x: jnp.swapaxes(x, ax1, ax2) args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_axis={!r}".format( @@ -1940,8 +1939,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np_fun = lambda x: np.squeeze(x, ax) jnp_fun = lambda x: jnp.squeeze(x, ax) args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_axis={}_weights={}_returned={}".format( @@ -2004,8 +2003,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): else: np_fun = partial(np.array, dtype=dtype) jnp_fun = jnp.array - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) def testArrayUnsupportedDtypeError(self): with self.assertRaisesRegex(TypeError, @@ -2042,8 +2041,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): ans = jnp.array(bytearray(b'\x2a')) self.assertAllClose( ans, - np.array([0x2a], dtype=np.uint8), - check_dtypes=True) + np.array([0x2a], dtype=np.uint8)) def testIsClose(self): c_isclose = api.jit(jnp.isclose) @@ -2195,8 +2193,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = self._GetArgsMaker(rng, [shape], [dtype]) jnp_op = lambda x: jnp.flip(x, axis) np_op = lambda x: np.flip(x, axis) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}".format( @@ -2210,8 +2208,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = self._GetArgsMaker(rng, [shape], [dtype]) jnp_op = lambda x: 
jnp.flipud(x) np_op = lambda x: np.flipud(x) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) @parameterized.named_parameters(jtu.cases_from_list( @@ -2226,8 +2224,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = self._GetArgsMaker(rng, [shape], [dtype]) jnp_op = lambda x: jnp.fliplr(x) np_op = lambda x: np.fliplr(x) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) @parameterized.named_parameters(jtu.cases_from_list( @@ -2248,15 +2246,15 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = self._GetArgsMaker(rng, [shape], [dtype]) jnp_op = lambda x: jnp.rot90(x, k, axes) np_op = lambda x: np.rot90(x, k, axes) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) # TODO(mattjj): test infix operator overrides def testRavel(self): rng = np.random.RandomState(0) args_maker = lambda: [rng.randn(3, 4).astype("float32")] - self._CompileAndCheck(lambda x: x.ravel(), args_maker, check_dtypes=True) + self._CompileAndCheck(lambda x: x.ravel(), args_maker) @parameterized.parameters( (0, (2, 1, 3)), @@ -2265,12 +2263,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): ([0, 1, 2], (2, 2)), ([[[0, 1], [2, 3]]], (2, 2))) def testUnravelIndex(self, flat_index, shape): - self._CheckAgainstNumpy( - np.unravel_index, - jnp.unravel_index, - lambda: (flat_index, shape), - check_dtypes=True - ) + self._CheckAgainstNumpy(np.unravel_index, jnp.unravel_index, + lambda: (flat_index, shape)) def testUnravelIndexOOB(self): self.assertEqual(jnp.unravel_index(2, (2,)), (1,)) @@ -2282,8 +2276,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = lambda: [rng.randn(3, 4).astype("float32")] np_op = lambda x: np.asarray(x).astype(jnp.int32) jnp_op = lambda x: jnp.asarray(x).astype(jnp.int32) - self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(jnp_op, np_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_dtype={}".format( @@ -2305,8 +2299,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): jnp_op = lambda x: jnp.asarray(x).view(dtype) # Above may produce signaling nans; ignore warnings from invalid values. 
with np.errstate(invalid='ignore'): - self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(jnp_op, np_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) def testPathologicalFloats(self): args_maker = lambda: [np.array([ @@ -2325,8 +2319,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np_op = lambda x: np.asarray(x).view('float32').view('uint32') jnp_op = lambda x: jnp.asarray(x).view('float32').view('uint32') - self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(jnp_op, np_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) # TODO(mattjj): test other ndarray-like method overrides @@ -2340,29 +2334,29 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): # from https://github.com/google/jax/issues/145 expected = np.arange(0.0, 1.0, 0.1, dtype=jnp.float_) ans = jnp.arange(0.0, 1.0, 0.1) - self.assertAllClose(expected, ans, check_dtypes=True) + self.assertAllClose(expected, ans) def testSortManually(self): # manual tests for sort are nice because we don't have to worry about ties. # lax.sort is tested combinatorially. ans = jnp.sort(np.array([16, 15, 23, 42, 8, 4])) expected = np.array([4, 8, 15, 16, 23, 42]) - self.assertAllClose(expected, ans, check_dtypes=True) + self.assertAllClose(expected, ans) a = np.array([[1, 4], [3, 1]]) ans = jnp.sort(a, axis=None) expected = np.array([1, 1, 3, 4]) - self.assertAllClose(expected, ans, check_dtypes=True) + self.assertAllClose(expected, ans) a = np.array([[1, 4], [3, 1]]) ans = jnp.sort(a) # last axis expected = np.array([[1, 4], [1, 3]]) - self.assertAllClose(expected, ans, check_dtypes=True) + self.assertAllClose(expected, ans) a = np.array([[1, 4], [3, 1]]) ans = jnp.sort(a, axis=0) expected = np.array([[1, 1], [3, 4]]) - self.assertAllClose(expected, ans, check_dtypes=True) + self.assertAllClose(expected, ans) def testArgsortManually(self): x = np.array([16, 15, 23, 42, 8, 4]) @@ -2394,8 +2388,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = lambda: [np.random.randint(50, size=(5 ,5))] jnp_op = lambda x: jnp.msort(x) np_op = lambda x: np.msort(x) - self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(jnp_op, np_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_shifts={}_axis={}".format( @@ -2420,8 +2414,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = lambda: [rng(shape, dtype), np.array(shifts)] jnp_op = partial(jnp.roll, axis=axis) np_op = partial(np.roll, axis=axis) - self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(jnp_op, np_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_axis={}_start={}".format( @@ -2439,8 +2433,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = lambda: [rng(shape, dtype)] jnp_op = partial(jnp.rollaxis, axis=axis, start=start) np_op = partial(np.rollaxis, axis=axis, start=start) - self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(jnp_op, np_op, args_maker) + 
self._CompileAndCheck(jnp_op, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_axis={}_bitorder={}".format( @@ -2459,8 +2453,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = lambda: [rng(shape, dtype)] jnp_op = partial(jnp.packbits, axis=axis, bitorder=bitorder) np_op = partial(np.packbits, axis=axis, bitorder=bitorder) - self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(jnp_op, np_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_axis={}_bitorder={}_count={}".format( @@ -2479,8 +2473,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = lambda: [rng(shape, dtype)] jnp_op = partial(jnp.unpackbits, axis=axis, bitorder=bitorder) np_op = partial(np.unpackbits, axis=axis, bitorder=bitorder) - self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(jnp_op, np_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_index={}_axis={}_mode={}".format( @@ -2506,8 +2500,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): rng_indices = jtu.rand_int(self.rng(), -5, 5) jnp_op = lambda x, i: jnp.take(x, i, axis=axis, mode=mode) np_op = lambda x, i: np.take(x, i, axis=axis, mode=mode) - self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(jnp_op, np_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_ishape={}_axis={}".format( @@ -2541,8 +2535,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): if hasattr(np, "take_along_axis"): np_op = lambda x, i: np.take_along_axis(x, i, axis=axis) - self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(jnp_op, np_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_n={}_increasing={}".format( @@ -2606,9 +2600,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): rng = rng_factory(self.rng()) args_maker = lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)] - self._CheckAgainstNumpy(np.ix_, jnp.ix_, args_maker, - check_dtypes=True) - self._CompileAndCheck(jnp.ix_, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np.ix_, jnp.ix_, args_maker) + self._CompileAndCheck(jnp.ix_, args_maker) @parameterized.named_parameters( jtu.cases_from_list( @@ -2630,8 +2623,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): dtype=dtype, sparse=sparse) jnp_fun = partial(jnp.indices, dimensions=dimensions, dtype=dtype, sparse=sparse) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": @@ -2681,7 +2674,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): jtu.tolerance(q_dtype, tol_spec)) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, 
check_dtypes=True, rtol=tol) + self._CompileAndCheck(jnp_fun, args_maker, rtol=tol) @parameterized.named_parameters(jtu.cases_from_list( @@ -2714,7 +2707,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): tol = jtu.tolerance(a_dtype, tol_spec) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, rtol=tol) + self._CompileAndCheck(jnp_fun, args_maker, rtol=tol) @parameterized.named_parameters(jtu.cases_from_list( @@ -2745,9 +2738,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = self._GetArgsMaker(rng_factory(self.rng()), shapes, dtypes) def np_fun(cond, x, y): return _promote_like_jnp(partial(np.where, cond))(x, y) - self._CheckAgainstNumpy(np_fun, jnp.where, args_maker, - check_dtypes=True) - self._CompileAndCheck(jnp.where, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp.where, args_maker) + self._CompileAndCheck(jnp.where, args_maker) def testWhereScalarPromotion(self): x = jnp.where(jnp.array([True, False]), 3, @@ -2782,7 +2774,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np.asarray(default, dtype=dtype)) self._CheckAgainstNumpy(np_fun, jnp.select, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp.select, args_maker, check_dtypes=True, + self._CompileAndCheck(jnp.select, args_maker, rtol={np.float64: 1e-7, np.complex128: 1e-7}) @@ -2823,7 +2815,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): a = np.arange(6) + 1 ans = jnp.reshape(a, (3, 2), order='F') expected = np.reshape(a, (3, 2), order='F') - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_op={}_dtype={}".format(op, pytype.__name__), @@ -2836,8 +2828,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np_fun = lambda arg: getattr(np, op)(arg).astype(dtype) jnp_fun = lambda arg: getattr(jnp, op)(arg) args_maker = lambda: [pytype(2)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( { @@ -2865,7 +2857,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): jnp_fun = partial(jnp.bincount, minlength=minlength, length=length) if length is not None: - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fun, args_maker) if length is None: self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) @@ -2906,32 +2898,29 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): ]))) def testBlock(self, input): args_maker = lambda: [input] - self._CheckAgainstNumpy(np.block, jnp.block, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp.block, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np.block, jnp.block, args_maker) + self._CompileAndCheck(jnp.block, args_maker) def testLongLong(self): - self.assertAllClose(np.int64(7), api.jit(lambda x: x)(np.longlong(7)), - check_dtypes=True) + self.assertAllClose(np.int64(7), api.jit(lambda x: x)(np.longlong(7))) def testArange(self): # test cases inspired by dask tests at # https://github.com/dask/dask/blob/master/dask/array/tests/test_creation.py#L92 self.assertAllClose(jnp.arange(77), - np.arange(77, dtype=jnp.int_), check_dtypes=True) + np.arange(77, dtype=jnp.int_)) self.assertAllClose(jnp.arange(2, 13), - np.arange(2, 13, dtype=jnp.int_), check_dtypes=True) + 
np.arange(2, 13, dtype=jnp.int_)) self.assertAllClose(jnp.arange(4, 21, 9), - np.arange(4, 21, 9, dtype=jnp.int_), check_dtypes=True) + np.arange(4, 21, 9, dtype=jnp.int_)) self.assertAllClose(jnp.arange(53, 5, -3), - np.arange(53, 5, -3, dtype=jnp.int_), - check_dtypes=True) + np.arange(53, 5, -3, dtype=jnp.int_)) self.assertAllClose(jnp.arange(77, dtype=float), - np.arange(77, dtype=float), check_dtypes=True) + np.arange(77, dtype=float)) self.assertAllClose(jnp.arange(2, 13, dtype=int), - np.arange(2, 13, dtype=int), check_dtypes=True) + np.arange(2, 13, dtype=int)) self.assertAllClose(jnp.arange(0, 1, -0.5), - np.arange(0, 1, -0.5, dtype=jnp.float_), - check_dtypes=True) + np.arange(0, 1, -0.5, dtype=jnp.float_)) self.assertRaises(TypeError, lambda: jnp.arange()) @@ -2976,8 +2965,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): # argument. return lax.tie_in(y, 7.) - self.assertAllClose(np.zeros(3,), api.grad(f)(np.ones(3,)), - check_dtypes=True) + self.assertAllClose(np.zeros(3,), api.grad(f)(np.ones(3,))) # NOTE(mattjj): I disabled this test when removing lax._safe_mul because this # is a numerical stability issue that should be solved with a custom jvp rule @@ -2985,8 +2973,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): # def testIssue777(self): # x = jnp.linspace(-200, 0, 4, dtype=np.float32) # f = api.grad(lambda x: jnp.sum(1 / (1 + jnp.exp(-x)))) - # self.assertAllClose(f(x), np.array([0., 0., 0., 0.25], dtype=np.float32), - # check_dtypes=True) + # self.assertAllClose(f(x), np.array([0., 0., 0., 0.25], dtype=np.float32)) @parameterized.named_parameters( jtu.cases_from_list( @@ -3017,7 +3004,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): expected = np_op(x) actual = jnp_op(x) tol = jtu.tolerance(dtype, {np.float32: 1e-3, np.float64: 1e-7}) - self.assertAllClose(expected, actual, check_dtypes=True, atol=tol, + self.assertAllClose(expected, actual, atol=tol, rtol=tol) def testIssue883(self): @@ -3067,9 +3054,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): not jnp.issubdtype(out_dtype, jnp.complexfloating)): self.assertRaises(ValueError, lambda: jnp_fun(*args_maker())) else: - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, rtol=tol, + self._CompileAndCheck(jnp_fun, args_maker, rtol=tol, atol=tol) @parameterized.named_parameters( @@ -3100,9 +3087,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): not jnp.issubdtype(out_dtype, jnp.complexfloating)): self.assertRaises(ValueError, lambda: jnp_fun(*args_maker())) else: - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, rtol=tol, + self._CompileAndCheck(jnp_fun, args_maker, rtol=tol, atol=tol) @parameterized.named_parameters( @@ -3128,7 +3115,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): tol = jtu.join_tolerance(tol, jtu.tolerance(dtype)) self._CheckAgainstNumpy( np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, atol=tol, + self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol) def testIssue967(self): @@ -3157,7 +3144,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CheckAgainstNumpy( np_fun, jnp_fun, args_maker, check_dtypes=False, tol=1e-2 if jtu.device_under_test() == "tpu" else None) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) 
+ self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_{}_{}".format(jtu.format_shape_dtype_string(shape, dtype), @@ -3184,8 +3171,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): (None if begin_dtype is None else rng(begin_shape, begin_dtype))] np_fun = lambda x, to_end, to_begin: np.ediff1d(x, to_end, to_begin) jnp_fun = lambda x, to_end, to_begin: jnp.ediff1d(x, to_end, to_begin) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) def testEDiff1dWithDtypeCast(self): rng = jtu.rand_default(self.rng()) @@ -3196,8 +3183,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = lambda: [rng(shape, dtype), rng(shape, end_dtype), rng(shape, dtype)] np_fun = lambda x, to_end, to_begin: np.ediff1d(x, to_end, to_begin) jnp_fun = lambda x, to_end, to_begin: jnp.ediff1d(x, to_end, to_begin) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters( jtu.cases_from_list( @@ -3216,7 +3203,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): [dtype] * len(shapes)) np_fun = partial(np.meshgrid, indexing=indexing, sparse=sparse) jnp_fun = partial(jnp.meshgrid, indexing=indexing, sparse=sparse) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fun, args_maker) @parameterized.named_parameters( jtu.cases_from_list( @@ -3272,7 +3259,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): rng = rng_factory(self.rng()) endpoints = rng((2,), dtype) out = jnp.linspace(*endpoints, 10, dtype=dtype) - self.assertAllClose(out[[0, -1]], endpoints, check_dtypes=True, rtol=0, atol=0) + self.assertAllClose(out[[0, -1]], endpoints, rtol=0, atol=0) @parameterized.named_parameters( jtu.cases_from_list( @@ -3454,8 +3441,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = self._GetArgsMaker(rng, [from_shape], [np.float32]) np_op = lambda x: np.broadcast_to(x, to_shape) jnp_op = lambda x: jnp.broadcast_to(x, to_shape) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) def testBroadcastToIssue1522(self): self.assertRaisesRegex( @@ -3540,7 +3527,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np_fun = lambda y: np.gradient(y, *varargs, axis=axis) self._CheckAgainstNumpy( np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fun, args_maker) def testZerosShapeErrors(self): # see https://github.com/google/jax/issues/1822 @@ -3556,9 +3543,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): def testTraceMethod(self): x = self.rng().randn(3, 4).astype(jnp.float_) - self.assertAllClose(x.trace(), jnp.array(x).trace(), check_dtypes=True) - self.assertAllClose(x.trace(), api.jit(lambda y: y.trace())(x), - check_dtypes=True) + self.assertAllClose(x.trace(), jnp.array(x).trace()) + self.assertAllClose(x.trace(), api.jit(lambda y: y.trace())(x)) def testIntegerPowersArePrecise(self): # See https://github.com/google/jax/pull/3036 diff --git a/tests/lax_numpy_vectorize_test.py 
b/tests/lax_numpy_vectorize_test.py index ca48e0a34..7c78da423 100644 --- a/tests/lax_numpy_vectorize_test.py +++ b/tests/lax_numpy_vectorize_test.py @@ -141,8 +141,8 @@ class VectorizeTest(jtu.JaxTestCase): return y x = jnp.arange(3) - self.assertAllClose(x, f('foo', x), check_dtypes=True) - self.assertAllClose(x, jax.jit(f, 0)('foo', x), check_dtypes=True) + self.assertAllClose(x, f('foo', x)) + self.assertAllClose(x, jax.jit(f, 0)('foo', x)) def test_exclude_second(self): @@ -153,8 +153,8 @@ class VectorizeTest(jtu.JaxTestCase): return x x = jnp.arange(3) - self.assertAllClose(x, f(x, 'foo'), check_dtypes=True) - self.assertAllClose(x, jax.jit(f, 1)(x, 'foo'), check_dtypes=True) + self.assertAllClose(x, f(x, 'foo')) + self.assertAllClose(x, jax.jit(f, 1)(x, 'foo')) def test_exclude_errors(self): with self.assertRaisesRegex( diff --git a/tests/lax_scipy_sparse_test.py b/tests/lax_scipy_sparse_test.py index 549e5a794..1a8c2439a 100644 --- a/tests/lax_scipy_sparse_test.py +++ b/tests/lax_scipy_sparse_test.py @@ -92,7 +92,6 @@ class LaxBackedScipyTests(jtu.JaxTestCase): partial(scipy_cg, M=M, maxiter=1), partial(lax_cg, M=M, maxiter=1), args_maker, - check_dtypes=True, tol=1e-3) # TODO(shoyer,mattjj): I had to loosen the tolerance for complex64[7,7] @@ -101,14 +100,12 @@ class LaxBackedScipyTests(jtu.JaxTestCase): partial(scipy_cg, M=M, maxiter=3), partial(lax_cg, M=M, maxiter=3), args_maker, - check_dtypes=True, tol=3e-3) self._CheckAgainstNumpy( np.linalg.solve, partial(lax_cg, M=M, atol=1e-6), args_maker, - check_dtypes=True, tol=2e-2) @parameterized.named_parameters(jtu.cases_from_list( @@ -125,10 +122,10 @@ class LaxBackedScipyTests(jtu.JaxTestCase): expected = np.linalg.solve(posify(a), b) actual = lax_cg(posify(a), b) - self.assertAllClose(expected, actual, check_dtypes=True) + self.assertAllClose(expected, actual) actual = jit(lax_cg)(posify(a), b) - self.assertAllClose(expected, actual, check_dtypes=True) + self.assertAllClose(expected, actual) # numerical gradients are only well defined if ``a`` is guaranteed to be # positive definite. 
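The comment above is why these tests route their operands through posify() before calling lax_cg: a conjugate-gradient solve and its numerical gradients are only meaningful on a positive definite matrix. posify itself is defined outside this hunk; a minimal sketch of what such a helper can look like, under the assumption that a symmetrize-and-shift construction is acceptable here:

import numpy as np

def posify_sketch(a, eps=1e-3):
  # a @ a.conj().T is Hermitian positive semidefinite; the eps * I shift
  # pushes every eigenvalue strictly above zero, which keeps CG and its
  # numerical gradients well defined.
  return a @ a.conj().T + eps * np.eye(a.shape[0], dtype=a.dtype)

a = np.random.RandomState(0).randn(4, 4).astype(np.float32)
print(np.linalg.eigvalsh(posify_sketch(a)).min() > 0)   # True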
@@ -141,7 +138,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase): b = jnp.arange(9.0).reshape((3, 3)) expected = b / 2 actual, _ = jax.scipy.sparse.linalg.cg(A, b) - self.assertAllClose(expected, actual, check_dtypes=True) + self.assertAllClose(expected, actual) def test_cg_pytree(self): A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]} diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index a300aa6ac..05675e53f 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -108,8 +108,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase): return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims) args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker) + self._CompileAndCheck(lax_fun, args_maker) @parameterized.named_parameters(itertools.chain.from_iterable( jtu.cases_from_list( @@ -129,7 +129,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase): args = args_maker() self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3, check_dtypes=False) - self._CompileAndCheck(lax_op, args_maker, check_dtypes=True, rtol=1e-5) + self._CompileAndCheck(lax_op, args_maker, rtol=1e-5) if test_autodiff: jtu.check_grads(lax_op, args, order=1, @@ -153,14 +153,14 @@ class LaxBackedScipyTests(jtu.JaxTestCase): rng = rng_factory(self.rng()) args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True, + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol={onp.float32: 1e-3, onp.float64: 1e-14}) - self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lax_fun, args_maker) def testIssue980(self): x = onp.full((4,), -1e20, dtype=onp.float32) self.assertAllClose(onp.zeros((4,), dtype=onp.float32), - lsp_special.expit(x), check_dtypes=True) + lsp_special.expit(x)) def testXlogyShouldReturnZero(self): self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False) diff --git a/tests/lax_test.py b/tests/lax_test.py index abaaf934d..56ec8e85a 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -187,7 +187,7 @@ class LaxTest(jtu.JaxTestCase): rng = rng_factory(self.rng()) args_maker = lambda: [rng(shape, dtype) for shape in shapes] op = getattr(lax, op_name) - self._CompileAndCheck(op, args_maker, check_dtypes=True) + self._CompileAndCheck(op, args_maker) @parameterized.named_parameters(itertools.chain.from_iterable( jtu.cases_from_list( @@ -222,7 +222,7 @@ class LaxTest(jtu.JaxTestCase): rng = rng_factory(self.rng()) args_maker = lambda: [rng((2, 3), from_dtype)] op = lambda x: lax.convert_element_type(x, to_dtype) - self._CompileAndCheck(op, args_maker, check_dtypes=True) + self._CompileAndCheck(op, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_from_dtype={}_to_dtype={}" @@ -249,7 +249,7 @@ class LaxTest(jtu.JaxTestCase): rng = rng_factory(self.rng()) args_maker = lambda: [rng((2, 3), from_dtype)] op = lambda x: lax.bitcast_convert_type(x, to_dtype) - self._CompileAndCheck(op, args_maker, check_dtypes=True) + self._CompileAndCheck(op, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_from_dtype={}_to_dtype={}" @@ -284,7 +284,7 @@ class LaxTest(jtu.JaxTestCase): rng = rng_factory(self.rng()) shapes = [min_shape, operand_shape, max_shape] args_maker = lambda: [rng(shape, dtype) for shape in 
shapes] - self._CompileAndCheck(lax.clamp, args_maker, check_dtypes=True) + self._CompileAndCheck(lax.clamp, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_min_shape={}_operand_shape={}_max_shape={}".format( @@ -325,7 +325,7 @@ class LaxTest(jtu.JaxTestCase): for size, _ in zip(itertools.cycle([3, 1, 4]), range(num_arrs))] args_maker = lambda: [rng(shape, dtype) for shape in shapes] op = lambda *args: lax.concatenate(args, dim) - self._CompileAndCheck(op, args_maker, check_dtypes=True) + self._CompileAndCheck(op, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format( @@ -368,7 +368,7 @@ class LaxTest(jtu.JaxTestCase): def fun(lhs, rhs): return lax.conv(lhs, rhs, strides, padding) - self._CompileAndCheck(fun, args_maker, check_dtypes=True) + self._CompileAndCheck(fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": @@ -419,7 +419,7 @@ class LaxTest(jtu.JaxTestCase): return lax.conv_with_general_padding( lhs, rhs, strides, padding, lhs_dilation, rhs_dilation) - self._CompileAndCheck(fun, args_maker, check_dtypes=True) + self._CompileAndCheck(fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}" @@ -502,7 +502,7 @@ class LaxTest(jtu.JaxTestCase): dimension_numbers, feature_group_count=feature_group_count, batch_group_count=batch_group_count) - self._CompileAndCheck(fun, args_maker, check_dtypes=True) + self._CompileAndCheck(fun, args_maker) # TODO(mattjj): test conv_general_dilated against numpy @@ -512,7 +512,7 @@ class LaxTest(jtu.JaxTestCase): return [rng((10, 5), onp.float32), rng((5, 7), onp.float32)] jnp_fun = partial(lax.conv_general_dilated, window_strides=(), padding='VALID', dimension_numbers=('NC', 'IO', 'NC')) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fun, args_maker) self._CheckAgainstNumpy(jnp_fun, onp.dot, args_maker, tol=.1) @@ -685,8 +685,7 @@ class LaxTest(jtu.JaxTestCase): def testDot(self, lhs_shape, rhs_shape, dtype, precision, rng_factory): rng = rng_factory(self.rng()) args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] - self._CompileAndCheck(partial(lax.dot, precision=precision), args_maker, - check_dtypes=True) + self._CompileAndCheck(partial(lax.dot, precision=precision), args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}".format( @@ -798,7 +797,7 @@ class LaxTest(jtu.JaxTestCase): rng = rng_factory(self.rng()) args_maker = lambda: [rng(shape, dtype)] op = lambda x: lax.broadcast(x, broadcast_sizes) - self._CompileAndCheck(op, args_maker, check_dtypes=True) + self._CompileAndCheck(op, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_broadcast_sizes={}".format( @@ -835,7 +834,7 @@ class LaxTest(jtu.JaxTestCase): rng = rng_factory(self.rng()) args_maker = lambda: [rng(inshape, dtype)] op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions) - self._CompileAndCheck(op, args_maker, check_dtypes=True) + self._CompileAndCheck(op, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_outshape={}_bcdims={}".format( @@ -921,7 +920,7 @@ class LaxTest(jtu.JaxTestCase): args_maker = lambda: [rng(arg_shape, onp.float32)] op = lambda x: lax.squeeze(x, dimensions) numpy_op = lambda x: lax_reference.squeeze(x, dimensions) - 
self._CompileAndCheck(op, args_maker, check_dtypes=True) + self._CompileAndCheck(op, args_maker) self._CheckAgainstNumpy(op, numpy_op, args_maker) check_grads(op, args_maker(), 2, ["fwd", "rev"], eps=1.) @@ -940,7 +939,7 @@ class LaxTest(jtu.JaxTestCase): rng = rng_factory(self.rng()) args_maker = lambda: [rng(arg_shape, dtype)] op = lambda x: lax.reshape(x, out_shape) - self._CompileAndCheck(op, args_maker, check_dtypes=True) + self._CompileAndCheck(op, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_outshape={}".format( @@ -971,7 +970,7 @@ class LaxTest(jtu.JaxTestCase): rng = rng_factory(self.rng()) args_maker = lambda: [rng(shape, dtype)] fun = lambda operand: lax.pad(operand, onp.array(0, dtype), pads) - self._CompileAndCheck(fun, args_maker, check_dtypes=True) + self._CompileAndCheck(fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_pads={}" @@ -1018,7 +1017,7 @@ class LaxTest(jtu.JaxTestCase): return [rng(pred_shape, onp.bool_), rng(arg_shape, arg_dtype), rng(arg_shape, arg_dtype)] rng = rng_factory(self.rng()) - return self._CompileAndCheck(lax.select, args_maker, check_dtypes=True) + return self._CompileAndCheck(lax.select, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_predshape={}_argshapes={}".format( @@ -1061,7 +1060,7 @@ class LaxTest(jtu.JaxTestCase): rng = rng_factory(self.rng()) args_maker = lambda: [rng(shape, dtype)] op = lambda x: lax.slice(x, starts, limits, strides) - self._CompileAndCheck(op, args_maker, check_dtypes=True) + self._CompileAndCheck(op, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": @@ -1109,7 +1108,7 @@ class LaxTest(jtu.JaxTestCase): rng = rng_factory(self.rng()) args_maker = lambda: [rng(shape, dtype), onp.array(start_indices)] op = lambda x, starts: lax.dynamic_slice(x, starts, size_indices) - self._CompileAndCheck(op, args_maker, check_dtypes=True) + self._CompileAndCheck(op, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_start_indices={}_size_indices={}".format( @@ -1159,8 +1158,7 @@ class LaxTest(jtu.JaxTestCase): return [rng(shape, dtype), rng(update_shape, dtype), onp.array(start_indices)] - self._CompileAndCheck(lax.dynamic_update_slice, args_maker, - check_dtypes=True) + self._CompileAndCheck(lax.dynamic_update_slice, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_start_indices={}_update_shape={}".format( @@ -1202,7 +1200,7 @@ class LaxTest(jtu.JaxTestCase): rng = rng_factory(self.rng()) args_maker = lambda: [rng(shape, dtype)] op = lambda x: lax.transpose(x, perm) - self._CompileAndCheck(op, args_maker, check_dtypes=True) + self._CompileAndCheck(op, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_perm={}".format( @@ -1257,13 +1255,13 @@ class LaxTest(jtu.JaxTestCase): init_val = onp.asarray(init_val, dtype=dtype) fun = lambda operand, init_val: lax.reduce(operand, init_val, op, dims) args_maker = lambda: [rng(shape, dtype), init_val] - self._CompileAndCheck(fun, args_maker, check_dtypes=True) + self._CompileAndCheck(fun, args_maker) # we separately test the version that uses a concrete init_val because it # can hit different code paths fun = lambda operand: lax.reduce(operand, init_val, op, dims) args_maker = lambda: [rng(shape, dtype)] - self._CompileAndCheck(fun, args_maker, check_dtypes=True) + self._CompileAndCheck(fun, 
args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_op={}_dtype={}_padding={}" @@ -1297,7 +1295,7 @@ class LaxTest(jtu.JaxTestCase): # pylint: disable=cell-var-from-loop for shape, dims, strides in all_configs: args_maker = lambda: [rng(shape, dtype), init_val] - self._CompileAndCheck(fun, args_maker, check_dtypes=True) + self._CompileAndCheck(fun, args_maker) # pylint: enable=cell-var-from-loop # we separately test the version that uses a concrete init_val because it @@ -1308,7 +1306,7 @@ class LaxTest(jtu.JaxTestCase): # pylint: disable=cell-var-from-loop for shape, dims, strides in all_configs: args_maker = lambda: [rng(shape, dtype)] - self._CompileAndCheck(fun, args_maker, check_dtypes=True) + self._CompileAndCheck(fun, args_maker) # pylint: enable=cell-var-from-loop @parameterized.named_parameters(jtu.cases_from_list( @@ -1329,9 +1327,9 @@ class LaxTest(jtu.JaxTestCase): def testCumulativeReduce(self, op, onp_op, shape, dtype, axis, rng_factory): rng = rng_factory(self.rng()) fun = partial(op, axis=axis) - onp_fun = partial(onp_op, axis=axis) + onp_fun = partial(onp_op, axis=axis, dtype=dtype) args_maker = lambda: [rng(shape, dtype)] - self._CompileAndCheck(fun, args_maker, check_dtypes=True) + self._CompileAndCheck(fun, args_maker) self._CheckAgainstNumpy(fun, onp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( @@ -1350,7 +1348,7 @@ class LaxTest(jtu.JaxTestCase): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] fun = lambda x: lax.sort(x, dimension=axis) - self._CompileAndCheck(fun, args_maker, check_dtypes=True) + self._CompileAndCheck(fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_axis={}".format( @@ -1399,7 +1397,7 @@ class LaxTest(jtu.JaxTestCase): return keys, values fun = lambda keys, values: lax.sort_key_val(keys, values, axis) - self._CompileAndCheck(fun, args_maker, check_dtypes=True) + self._CompileAndCheck(fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_keyshape={}_valshape={}_axis={}".format( @@ -1447,12 +1445,12 @@ class LaxTest(jtu.JaxTestCase): values = self.rng().permutation(flat_values).reshape(shape) return [values] def reference_top_k(x): - bcast_idxs = onp.broadcast_to(onp.arange(shape[-1]), shape) + bcast_idxs = onp.broadcast_to(onp.arange(shape[-1], dtype=onp.int32), shape) sorted_vals, sorted_idxs = lax_reference.sort_key_val(x, bcast_idxs) return sorted_vals[..., :-k-1:-1], sorted_idxs[..., :-k-1:-1] op = lambda vs: lax.top_k(vs, k=k) self._CheckAgainstNumpy(op, reference_top_k, args_maker) - self._CompileAndCheck(op, args_maker, check_dtypes=True) + self._CompileAndCheck(op, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}" @@ -1468,7 +1466,7 @@ class LaxTest(jtu.JaxTestCase): def testBatchMatMul(self, lhs_shape, rhs_shape, dtype, rng_factory): rng = rng_factory(self.rng()) arg_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] - self._CompileAndCheck(lax.batch_matmul, arg_maker, check_dtypes=True) + self._CompileAndCheck(lax.batch_matmul, arg_maker) def testCollapse(self): @@ -1498,7 +1496,7 @@ class LaxTest(jtu.JaxTestCase): rand_idxs = lambda: tuple(rng(e.shape, e.dtype) for e in idxs) args_maker = lambda: [rng(shape, dtype), rand_idxs()] fun = lambda src, idxs: lax.index_take(src, idxs, axes) - self._CompileAndCheck(fun, args_maker, check_dtypes=True) + self._CompileAndCheck(fun, args_maker) 
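Two of the reference-side tweaks in the hunks above exist purely to keep dtypes in sync with lax: testCumulativeReduce now passes dtype=dtype to the NumPy cumulative op, and reference_top_k builds its broadcast indices as int32 to match the index dtype lax.top_k produces. Both work around NumPy's default integer promotion; a standalone illustration (not part of the patch) of what would otherwise go wrong:

import numpy as np

x = np.arange(4, dtype=np.int8)
# NumPy accumulates small integer types in the default platform integer
# unless a dtype is requested explicitly, so the reference result would
# fail an exact dtype comparison against the lax output.
print(np.cumsum(x).dtype)                 # default platform int, e.g. int64
print(np.cumsum(x, dtype=np.int8).dtype)  # int8
# np.arange likewise defaults to the platform integer, hence the explicit
# int32 when building the reference top_k indices.
print(np.arange(4).dtype, np.arange(4, dtype=np.int32).dtype)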
@parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}".format( @@ -1531,7 +1529,7 @@ class LaxTest(jtu.JaxTestCase): rand_idxs = lambda: rng_idx(idxs.shape, idxs.dtype) args_maker = lambda: [rng(shape, dtype), rand_idxs()] fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) - self._CompileAndCheck(fun, args_maker, check_dtypes=True) + self._CompileAndCheck(fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format( @@ -1562,7 +1560,7 @@ class LaxTest(jtu.JaxTestCase): args_maker = lambda: [rng(arg_shape, dtype), rand_idxs(), rng(update_shape, dtype)] fun = partial(lax.scatter_add, dimension_numbers=dnums) - self._CompileAndCheck(fun, args_maker, check_dtypes=True) + self._CompileAndCheck(fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format( @@ -1593,7 +1591,7 @@ class LaxTest(jtu.JaxTestCase): args_maker = lambda: [rng(arg_shape, dtype), rand_idxs(), rng(update_shape, dtype)] fun = partial(lax.scatter_min, dimension_numbers=dnums) - self._CompileAndCheck(fun, args_maker, check_dtypes=True) + self._CompileAndCheck(fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format( @@ -1624,7 +1622,7 @@ class LaxTest(jtu.JaxTestCase): args_maker = lambda: [rng(arg_shape, dtype), rand_idxs(), rng(update_shape, dtype)] fun = partial(lax.scatter_max, dimension_numbers=dnums) - self._CompileAndCheck(fun, args_maker, check_dtypes=True) + self._CompileAndCheck(fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format( @@ -1655,7 +1653,7 @@ class LaxTest(jtu.JaxTestCase): args_maker = lambda: [rng(arg_shape, dtype), rand_idxs(), rng(update_shape, dtype)] fun = partial(lax.scatter, dimension_numbers=dnums) - self._CompileAndCheck(fun, args_maker, check_dtypes=True) + self._CompileAndCheck(fun, args_maker) def testIssue831(self): # Tests the DeviceTuple constant handler @@ -1667,7 +1665,7 @@ class LaxTest(jtu.JaxTestCase): def testReshapeWithUnusualShapes(self): ans = lax.reshape(onp.ones((3,), onp.float32), (lax.add(1, 2), 1)) - self.assertAllClose(ans, onp.ones((3, 1), onp.float32), check_dtypes=True) + self.assertAllClose(ans, onp.ones((3, 1), onp.float32)) self.assertRaisesRegex( TypeError, @@ -1722,9 +1720,9 @@ class LazyConstantTest(jtu.JaxTestCase): jit_result = api.jit(lambda x: lax.add(x, make_const()))(zero) # ensure they're all the same - self.assertAllClose(asarray_result, expected, check_dtypes=True) - self.assertAllClose(argument_result, expected, check_dtypes=True) - self.assertAllClose(jit_result, expected, check_dtypes=True) + self.assertAllClose(asarray_result, expected) + self.assertAllClose(argument_result, expected) + self.assertAllClose(jit_result, expected) # ensure repr doesn't crash repr(make_const()) @@ -2748,11 +2746,11 @@ class LaxAutodiffTest(jtu.JaxTestCase): x = 3.14 ans = api.grad(f)(x) expected = api.grad(f2)(x, x) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) ans = api.grad(api.grad(f))(x) expected = api.grad(api.grad(f2))(x, x) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) ans = api.grad(lambda x: lax.stop_gradient({'foo':x})['foo'])(3.) 
expected = onp.array(0.0) @@ -2784,8 +2782,7 @@ class LaxAutodiffTest(jtu.JaxTestCase): # N.B.: intentionally written as 1/x, not x ** -1 or reciprocal(x) return 1 / x grad_fn = jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(inv)))))) - self.assertAllClose(onp.float32(0.0439453125), grad_fn(onp.float32(4.)), - check_dtypes=True) + self.assertAllClose(onp.float32(0.0439453125), grad_fn(onp.float32(4.))) def all_bdims(*shapes): bdims = (itertools.chain([cast(Optional[int], None)], @@ -2819,7 +2816,7 @@ class LaxVmapTest(jtu.JaxTestCase): args_slice = args_slicer(args, bdims) ans = api.vmap(op, bdims)(*args) expected = onp.stack([op(*args_slice(i)) for i in range(bdim_size)]) - self.assertAllClose(ans, expected, check_dtypes=True, rtol=rtol, atol=atol) + self.assertAllClose(ans, expected, rtol=rtol, atol=atol) @parameterized.named_parameters(itertools.chain.from_iterable( jtu.cases_from_list( diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 2a46f4416..509a74feb 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -71,8 +71,8 @@ class NumpyLinalgTest(jtu.JaxTestCase): self.skipTest("Unimplemented case for complex Cholesky decomposition.") self._CheckAgainstNumpy(np.linalg.cholesky, jnp.linalg.cholesky, args_maker, - check_dtypes=True, tol=1e-3) - self._CompileAndCheck(jnp.linalg.cholesky, args_maker, check_dtypes=True) + tol=1e-3) + self._CompileAndCheck(jnp.linalg.cholesky, args_maker) if jnp.finfo(dtype).bits == 64: jtu.check_grads(jnp.linalg.cholesky, args_maker(), order=2) @@ -96,14 +96,13 @@ class NumpyLinalgTest(jtu.JaxTestCase): _skip_if_unsupported_type(dtype) args_maker = lambda: [rng((n, n), dtype)] - self._CheckAgainstNumpy(np.linalg.det, jnp.linalg.det, args_maker, - check_dtypes=True, tol=1e-3) - self._CompileAndCheck(jnp.linalg.det, args_maker, check_dtypes=True, + self._CheckAgainstNumpy(np.linalg.det, jnp.linalg.det, args_maker, tol=1e-3) + self._CompileAndCheck(jnp.linalg.det, args_maker, rtol={np.float64: 1e-13, np.complex128: 1e-13}) def testDetOfSingularMatrix(self): x = jnp.array([[-1., 3./2], [2./3, -1.]], dtype=np.float32) - self.assertAllClose(np.float32(0), jsp.linalg.det(x), check_dtypes=True) + self.assertAllClose(np.float32(0), jsp.linalg.det(x)) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": @@ -177,10 +176,9 @@ class NumpyLinalgTest(jtu.JaxTestCase): self._CheckAgainstNumpy(np.linalg.tensorsolve, jnp.linalg.tensorsolve, args_maker, - check_dtypes=True, tol={np.float32: 1e-2, np.float64: 1e-3}) self._CompileAndCheck(jnp.linalg.tensorsolve, - args_maker, check_dtypes=True, + args_maker, rtol={np.float64: 1e-13}) @parameterized.named_parameters(jtu.cases_from_list( @@ -198,8 +196,8 @@ class NumpyLinalgTest(jtu.JaxTestCase): args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(np.linalg.slogdet, jnp.linalg.slogdet, args_maker, - check_dtypes=True, tol=1e-3) - self._CompileAndCheck(jnp.linalg.slogdet, args_maker, check_dtypes=True) + tol=1e-3) + self._CompileAndCheck(jnp.linalg.slogdet, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": @@ -221,7 +219,7 @@ class NumpyLinalgTest(jtu.JaxTestCase): mat = jnp.array([np.diag(np.ones([5], dtype=np.float32))*(-.01)] * 2) args_maker = lambda: [mat] self._CheckAgainstNumpy(np.linalg.slogdet, jnp.linalg.slogdet, args_maker, - check_dtypes=True, tol=1e-3) + tol=1e-3) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}".format( @@ -249,7 +247,7 @@ class NumpyLinalgTest(jtu.JaxTestCase): 
self.assertTrue(np.all(norm(np.matmul(a, v) - w[..., None, :] * v) < 100)) self._CompileAndCheck(partial(jnp.linalg.eig), args_maker, - check_dtypes=True, rtol=1e-3) + rtol=1e-3) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}".format( @@ -269,7 +267,7 @@ class NumpyLinalgTest(jtu.JaxTestCase): a, = args_maker() w1, _ = jnp.linalg.eig(a) w2 = jnp.linalg.eigvals(a) - self.assertAllClose(w1, w2, check_dtypes=True) + self.assertAllClose(w1, w2) @jtu.skip_on_devices("gpu", "tpu") def testEigvalsInf(self): @@ -328,7 +326,7 @@ class NumpyLinalgTest(jtu.JaxTestCase): self.assertTrue(norm(np.matmul(a, v) - w * v) < tol) self._CompileAndCheck(partial(jnp.linalg.eigh, UPLO=uplo), args_maker, - check_dtypes=True, rtol=1e-3) + rtol=1e-3) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}".format( @@ -349,7 +347,7 @@ class NumpyLinalgTest(jtu.JaxTestCase): a = (a + np.conj(a.T)) / 2 return [a] self._CheckAgainstNumpy(np.linalg.eigvalsh, jnp.linalg.eigvalsh, args_maker, - check_dtypes=True, tol=1e-3) + tol=1e-3) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": @@ -478,9 +476,9 @@ class NumpyLinalgTest(jtu.JaxTestCase): args_maker = lambda: [rng(shape, dtype)] np_fn = partial(np.linalg.norm, ord=ord, axis=axis, keepdims=keepdims) np_fn = partial(jnp.linalg.norm, ord=ord, axis=axis, keepdims=keepdims) - self._CheckAgainstNumpy(np_fn, np_fn, args_maker, - check_dtypes=False, tol=1e-3) - self._CompileAndCheck(np_fn, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(np_fn, np_fn, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(np_fn, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_n={}_full_matrices={}_compute_uv={}".format( @@ -533,7 +531,7 @@ class NumpyLinalgTest(jtu.JaxTestCase): self.assertTrue(np.allclose(np.linalg.svd(a, compute_uv=False), np.asarray(out), atol=1e-4, rtol=1e-4)) self._CompileAndCheck(partial(jnp.linalg.svd, full_matrices=full_matrices, compute_uv=compute_uv), - args_maker, check_dtypes=True) + args_maker) if not (compute_uv and full_matrices): svd = partial(jnp.linalg.svd, full_matrices=full_matrices, compute_uv=compute_uv) @@ -698,8 +696,8 @@ class NumpyLinalgTest(jtu.JaxTestCase): args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] self._CheckAgainstNumpy(np.linalg.solve, jnp.linalg.solve, args_maker, - check_dtypes=True, tol=1e-3) - self._CompileAndCheck(jnp.linalg.solve, args_maker, check_dtypes=True) + tol=1e-3) + self._CompileAndCheck(jnp.linalg.solve, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": @@ -726,8 +724,8 @@ class NumpyLinalgTest(jtu.JaxTestCase): return [a] self._CheckAgainstNumpy(np.linalg.inv, jnp.linalg.inv, args_maker, - check_dtypes=True, tol=1e-3) - self._CompileAndCheck(jnp.linalg.inv, args_maker, check_dtypes=True) + tol=1e-3) + self._CompileAndCheck(jnp.linalg.inv, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": @@ -743,8 +741,8 @@ class NumpyLinalgTest(jtu.JaxTestCase): args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(np.linalg.pinv, jnp.linalg.pinv, args_maker, - check_dtypes=True, tol=1e-2) - self._CompileAndCheck(jnp.linalg.pinv, args_maker, check_dtypes=True) + tol=1e-2) + self._CompileAndCheck(jnp.linalg.pinv, args_maker) # TODO(phawkins): 1e-1 seems like a very loose tolerance. 
jtu.check_grads(jnp.linalg.pinv, args_maker(), 2, rtol=1e-1, atol=2e-1) @@ -754,15 +752,13 @@ class NumpyLinalgTest(jtu.JaxTestCase): a = jnp.array([[0., 0.],[-p, 1.]], jnp.float32) * 1 / (1 + p**2) return jnp.linalg.pinv(a) j = jax.jacobian(f)(jnp.float32(2.)) - self.assertAllClose(jnp.array([[0., -1.], [ 0., 0.]], jnp.float32), j, - check_dtypes=True) + self.assertAllClose(jnp.array([[0., -1.], [ 0., 0.]], jnp.float32), j) expected = jnp.array([[[[-1., 0.], [ 0., 0.]], [[0., -1.], [0., 0.]]], [[[0., 0.], [-1., 0.]], [[0., 0.], [0., -1.]]]], dtype=jnp.float32) self.assertAllClose( - expected, jax.jacobian(jnp.linalg.pinv)(jnp.eye(2, dtype=jnp.float32)), - check_dtypes=True) + expected, jax.jacobian(jnp.linalg.pinv)(jnp.eye(2, dtype=jnp.float32))) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_n={}".format( @@ -781,9 +777,9 @@ class NumpyLinalgTest(jtu.JaxTestCase): tol = 1e-1 if jtu.device_under_test() == "tpu" else 1e-3 self._CheckAgainstNumpy(partial(np.linalg.matrix_power, n=n), partial(jnp.linalg.matrix_power, n=n), - args_maker, check_dtypes=True, tol=tol) + args_maker, tol=tol) self._CompileAndCheck(partial(jnp.linalg.matrix_power, n=n), args_maker, - check_dtypes=True, rtol=1e-3) + rtol=1e-3) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}".format( @@ -825,9 +821,8 @@ class NumpyLinalgTest(jtu.JaxTestCase): tol = {np.float32: 1e-4, np.float64: 1e-10, np.complex64: 1e-4, np.complex128: 1e-10} - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, - tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol) @parameterized.named_parameters(jtu.cases_from_list( @@ -865,7 +860,7 @@ class NumpyLinalgTest(jtu.JaxTestCase): return [lhs, rng(rhs_shape, dtype)] self._CheckAgainstNumpy(onp_fun, jnp_fun_numpy_resid, args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, atol=tol, rtol=tol) + self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol) # Disabled because grad is flaky for low-rank inputs. 
# TODO: @@ -880,7 +875,7 @@ class NumpyLinalgTest(jtu.JaxTestCase): grad_test_jc = jit(grad(jit(test))) xc = np.eye(3, dtype=np.complex) - self.assertAllClose(xc, grad_test_jc(xc), check_dtypes=True) + self.assertAllClose(xc, grad_test_jc(xc)) @jtu.skip_on_flag("jax_skip_slow_tests", True) def testIssue1151(self): @@ -888,8 +883,7 @@ class NumpyLinalgTest(jtu.JaxTestCase): A = jnp.array(rng.randn(100, 3, 3), dtype=jnp.float32) b = jnp.array(rng.randn(100, 3), dtype=jnp.float32) x = jnp.linalg.solve(A, b) - self.assertAllClose(vmap(jnp.dot)(A, x), b, atol=2e-3, rtol=1e-2, - check_dtypes=True) + self.assertAllClose(vmap(jnp.dot)(A, x), b, atol=2e-3, rtol=1e-2) jac0 = jax.jacobian(jnp.linalg.solve, argnums=0)(A, b) jac1 = jax.jacobian(jnp.linalg.solve, argnums=1)(A, b) jac0 = jax.jacobian(jnp.linalg.solve, argnums=0)(A[0], b[0]) @@ -926,8 +920,8 @@ class ScipyLinalgTest(jtu.JaxTestCase): def testBlockDiag(self, args): args_maker = lambda: args self._CheckAgainstNumpy(osp.linalg.block_diag, jsp.linalg.block_diag, - args_maker, check_dtypes=True) - self._CompileAndCheck(jsp.linalg.block_diag, args_maker, check_dtypes=True) + args_maker) + self._CompileAndCheck(jsp.linalg.block_diag, args_maker) @parameterized.named_parameters(jtu.cases_from_list( @@ -943,15 +937,15 @@ class ScipyLinalgTest(jtu.JaxTestCase): args_maker = lambda: [rng(shape, dtype)] x, = args_maker() p, l, u = jsp.linalg.lu(x) - self.assertAllClose(x, np.matmul(p, np.matmul(l, u)), check_dtypes=True, + self.assertAllClose(x, np.matmul(p, np.matmul(l, u)), rtol={np.float32: 1e-3, np.float64: 1e-12, np.complex64: 1e-3, np.complex128: 1e-12}) - self._CompileAndCheck(jsp.linalg.lu, args_maker, check_dtypes=True) + self._CompileAndCheck(jsp.linalg.lu, args_maker) def testLuOfSingularMatrix(self): x = jnp.array([[-1., 3./2], [2./3, -1.]], dtype=np.float32) p, l, u = jsp.linalg.lu(x) - self.assertAllClose(x, np.matmul(p, np.matmul(l, u)), check_dtypes=True) + self.assertAllClose(x, np.matmul(p, np.matmul(l, u))) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": @@ -986,9 +980,9 @@ class ScipyLinalgTest(jtu.JaxTestCase): us = np.stack([out[2] for out in expected]) actual_ps, actual_ls, actual_us = vmap(jsp.linalg.lu)(jnp.stack(args)) - self.assertAllClose(ps, actual_ps, check_dtypes=True) - self.assertAllClose(ls, actual_ls, check_dtypes=True) - self.assertAllClose(us, actual_us, check_dtypes=True) + self.assertAllClose(ps, actual_ps) + self.assertAllClose(ls, actual_ls) + self.assertAllClose(us, actual_us) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": @@ -1008,9 +1002,9 @@ class ScipyLinalgTest(jtu.JaxTestCase): u = np.triu(lu) for i in range(n): x[[i, piv[i]],] = x[[piv[i], i],] - self.assertAllClose(x, np.matmul(l, u), check_dtypes=True, rtol=1e-3, + self.assertAllClose(x, np.matmul(l, u), rtol=1e-3, atol=1e-3) - self._CompileAndCheck(jsp.linalg.lu_factor, args_maker, check_dtypes=True) + self._CompileAndCheck(jsp.linalg.lu_factor, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": @@ -1040,9 +1034,8 @@ class ScipyLinalgTest(jtu.JaxTestCase): lu, piv = osp.linalg.lu_factor(a) return [lu, piv, rng(rhs_shape, dtype)] - self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, - check_dtypes=True, tol=1e-3) - self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=1e-3) + self._CompileAndCheck(jsp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": @@ -1081,9 
+1074,8 @@ class ScipyLinalgTest(jtu.JaxTestCase): a = np.tril(a) if lower else np.triu(a) return [a, rng(rhs_shape, dtype)] - self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, - check_dtypes=True, tol=1e-3) - self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=1e-3) + self._CompileAndCheck(jsp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": @@ -1132,7 +1124,7 @@ class ScipyLinalgTest(jtu.JaxTestCase): l if lower else T(l), b, trans=1 if transpose_a else 0, lower=lower, unit_diagonal=unit_diagonal) - self.assertAllClose(np_ans, ans, check_dtypes=True, + self.assertAllClose(np_ans, ans, rtol={np.float32: 1e-4, np.float64: 1e-11}) @parameterized.named_parameters(jtu.cases_from_list( @@ -1228,15 +1220,13 @@ class ScipyLinalgTest(jtu.JaxTestCase): osp_fun = lambda a: osp.linalg.expm(a) jsp_fun = lambda a: jsp.linalg.expm(a) - self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, - check_dtypes=True) - self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker) + self._CompileAndCheck(jsp_fun, args_maker) args_maker_triu = lambda: [np.triu(rng((n, n), dtype))] - jsp_fun_triu = lambda a: jsp.linalg.expm(a,upper_triangular=True) - self._CheckAgainstNumpy(osp_fun, jsp_fun_triu, args_maker_triu, - check_dtypes=True) - self._CompileAndCheck(jsp_fun_triu, args_maker_triu, check_dtypes=True) + jsp_fun_triu = lambda a: jsp.linalg.expm(a, upper_triangular=True) + self._CheckAgainstNumpy(osp_fun, jsp_fun_triu, args_maker_triu) + self._CompileAndCheck(jsp_fun_triu, args_maker_triu) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": @@ -1249,9 +1239,8 @@ class ScipyLinalgTest(jtu.JaxTestCase): args_maker_zeros = lambda: [np.zeros((n, n), dtype)] osp_fun = lambda a: osp.linalg.expm(a) jsp_fun = lambda a: jsp.linalg.expm(a) - self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker_zeros, - check_dtypes=True) - self._CompileAndCheck(jsp_fun, args_maker_zeros, check_dtypes=True) + self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker_zeros) + self._CompileAndCheck(jsp_fun, args_maker_zeros) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs={}_rhs={}_lower={}".format( @@ -1280,7 +1269,7 @@ class ScipyLinalgTest(jtu.JaxTestCase): U = np.triu(rng(lhs_shape, dtype)) return [(U, lower), b] self._CheckAgainstNumpy(osp.linalg.cho_solve, jsp.linalg.cho_solve, - args_maker, check_dtypes=True, tol=1e-3) + args_maker, tol=1e-3) if __name__ == "__main__": diff --git a/tests/loops_test.py b/tests/loops_test.py index 452524b6b..07c138729 100644 --- a/tests/loops_test.py +++ b/tests/loops_test.py @@ -35,7 +35,7 @@ class LoopsTest(jtu.JaxTestCase): with loops.Scope() as s: s.x = r + 1 return s.x - self.assertAllClose(4.0, f_op(3.), check_dtypes=True) + self.assertAllClose(4.0, f_op(3.)) def test_loop_empty(self): def f_op(r): @@ -44,7 +44,7 @@ class LoopsTest(jtu.JaxTestCase): pass return r - self.assertAllClose(3.0, f_op(3.), check_dtypes=True) + self.assertAllClose(3.0, f_op(3.)) def test_loop_1(self): """One loop with one state var, with transforms.""" @@ -56,14 +56,14 @@ class LoopsTest(jtu.JaxTestCase): return s.out def f_expected(inc): return 10 + 5 * inc - self.assertAllClose(f_expected(2.), f_op(2.), check_dtypes=True) - self.assertAllClose(f_expected(2.), api.jit(f_op)(2.), check_dtypes=True) - self.assertAllClose(5., api.grad(f_op)(2.), check_dtypes=True) - self.assertAllClose(5., api.grad(f_op)(2.), 
check_dtypes=True) + self.assertAllClose(f_expected(2.), f_op(2.)) + self.assertAllClose(f_expected(2.), api.jit(f_op)(2.)) + self.assertAllClose(5., api.grad(f_op)(2.)) + self.assertAllClose(5., api.grad(f_op)(2.)) inc_batch = np.arange(5, dtype=jnp.float_) self.assertAllClose(jnp.array([f_expected(inc) for inc in inc_batch], dtype=jnp.float_), - api.vmap(f_op)(inc_batch), check_dtypes=True) + api.vmap(f_op)(inc_batch)) def test_loop_2(self): @@ -77,7 +77,7 @@ class LoopsTest(jtu.JaxTestCase): s.out2 += 1. return (s.out1, s.out2) - self.assertAllClose((10. + 2. * 5, 20. + 1. * 5), f_op(2.), check_dtypes=True) + self.assertAllClose((10. + 2. * 5, 20. + 1. * 5), f_op(2.)) def test_add_vectors(self): @@ -92,7 +92,7 @@ class LoopsTest(jtu.JaxTestCase): x = jnp.array([1., 2., 3.], dtype=jnp.float32) y = jnp.array([4., 5., 6.], dtype=jnp.float32) - self.assertAllClose(jnp.add(x, y), add_vec(x, y), check_dtypes=True) + self.assertAllClose(jnp.add(x, y), add_vec(x, y)) def test_matmul(self): def matmul(x, y): @@ -109,7 +109,7 @@ class LoopsTest(jtu.JaxTestCase): x = jnp.array([[1., 2., 3.]], dtype=jnp.float32) # 1x3 y = jnp.array([[4.], [5.], [6.]], dtype=jnp.float32) # 3x1 - self.assertAllClose(jnp.matmul(x, y), matmul(x, y), check_dtypes=True) + self.assertAllClose(jnp.matmul(x, y), matmul(x, y)) def test_reuse_range(self): """Ranges can be reused, as long as not nested in each other.""" @@ -136,7 +136,7 @@ class LoopsTest(jtu.JaxTestCase): s.out += inc return s.out - self.assertAllClose(10. + 5 * (2. + 6 * 2.), f_op(2.), check_dtypes=True) + self.assertAllClose(10. + 5 * (2. + 6 * 2.), f_op(2.)) def test_example_doc(self): "The example from the module docstring." @@ -169,8 +169,8 @@ class LoopsTest(jtu.JaxTestCase): s.arr = ops.index_update(s.arr, i, s.arr[i] + 1.) return s.arr - self.assertAllClose(f_expected(), f_op_jax(), check_dtypes=True) - self.assertAllClose(f_expected(), f_op_loops(), check_dtypes=True) + self.assertAllClose(f_expected(), f_op_jax()) + self.assertAllClose(f_expected(), f_op_loops()) def test_loop_mutable_used_but_not_changed(self): def f_op(inc): @@ -184,7 +184,7 @@ class LoopsTest(jtu.JaxTestCase): return save_to_other_var - self.assertAllClose(10. + 5 * 2., f_op(2.), check_dtypes=True) + self.assertAllClose(10. + 5 * 2., f_op(2.)) def test_range_locations(self): """Ranges have locations.""" @@ -257,7 +257,7 @@ class LoopsTest(jtu.JaxTestCase): pass return i - self.assertAllClose(4, f_op(4), check_dtypes=True) + self.assertAllClose(4, f_op(4)) def test_error_new_state_in_loop(self): """Error when creating new state in a loop.""" @@ -281,13 +281,13 @@ class LoopsTest(jtu.JaxTestCase): s.out += inc return s.out - self.assertAllClose(16., f_op(0, 4, 4.), check_dtypes=True) + self.assertAllClose(16., f_op(0, 4, 4.)) # Ok to jit, as long as the start and end are static - self.assertAllClose(16., api.jit(f_op, static_argnums=(0, 1))(0, 4, 4.), check_dtypes=True) + self.assertAllClose(16., api.jit(f_op, static_argnums=(0, 1))(0, 4, 4.)) with self.assertRaisesRegex(TypeError, "Abstract tracer value encountered where concrete value is expected"): - self.assertAllClose(16., api.jit(f_op)(0, 4, 4.), check_dtypes=True) + self.assertAllClose(16., api.jit(f_op)(0, 4, 4.)) with self.assertRaisesRegex(TypeError, "Abstract tracer value encountered where concrete value is expected"): - self.assertAllClose(16., api.vmap(f_op)(jnp.zeros(10), jnp.ones(10), jnp.array([4.] * 10)), check_dtypes=True) + self.assertAllClose(16., api.vmap(f_op)(jnp.zeros(10), jnp.ones(10), jnp.array([4.] 
* 10))) def test_cond(self): def f_op(inc): @@ -297,8 +297,8 @@ class LoopsTest(jtu.JaxTestCase): s.out += inc return s.out - self.assertAllClose(10. + 2., f_op(2.), check_dtypes=True) - self.assertAllClose(10., f_op(-2.), check_dtypes=True) + self.assertAllClose(10. + 2., f_op(2.)) + self.assertAllClose(10., f_op(-2.)) def test_cond_state(self): """Conditionals predicated on scope fields.""" @@ -309,8 +309,8 @@ class LoopsTest(jtu.JaxTestCase): s.out *= 2. return s.out - self.assertAllClose(2. * 2., f_op(2.), check_dtypes=True) - self.assertAllClose(-2., f_op(-2.), check_dtypes=True) + self.assertAllClose(2. * 2., f_op(2.)) + self.assertAllClose(-2., f_op(-2.)) def test_cond_nested(self): """Nested conditionals.""" @@ -341,7 +341,7 @@ class LoopsTest(jtu.JaxTestCase): return s.out for init in [-1., 0., 9., 10.]: - self.assertAllClose(f_expected(init), f_op(init), check_dtypes=True) + self.assertAllClose(f_expected(init), f_op(init)) def test_error_cond_using_index_var(self): @@ -373,13 +373,13 @@ class LoopsTest(jtu.JaxTestCase): out += 1. return out - self.assertAllClose(f_expected(2.), f_op(2.), check_dtypes=True) - self.assertAllClose(f_expected(2.), api.jit(f_op)(2.), check_dtypes=True) - self.assertAllClose(f_expected(1.), f_op(1.), check_dtypes=True) + self.assertAllClose(f_expected(2.), f_op(2.)) + self.assertAllClose(f_expected(2.), api.jit(f_op)(2.)) + self.assertAllClose(f_expected(1.), f_op(1.)) init_batch = np.array([1., 2., 3.], dtype=np.float32) self.assertAllClose(np.array([f_expected(init) for init in init_batch], dtype=np.float32), - api.vmap(f_op)(init_batch), check_dtypes=True) + api.vmap(f_op)(init_batch)) def test_error_while_cond_mutation(self): """Disallow mutation in the while conditional.""" diff --git a/tests/masking_test.py b/tests/masking_test.py index f5351eae9..42ec6a666 100644 --- a/tests/masking_test.py +++ b/tests/masking_test.py @@ -601,8 +601,8 @@ class MaskingTest(jtu.JaxTestCase): def test_slice_oob_indexing(self): # https://github.com/google/jax/issues/2245 - self.assertAllClose(jnp.ones(5), jnp.ones(5)[:10], check_dtypes=True) - self.assertAllClose(jnp.ones(5), jnp.ones(5)[-10:], check_dtypes=True) + self.assertAllClose(jnp.ones(5), jnp.ones(5)[:10]) + self.assertAllClose(jnp.ones(5), jnp.ones(5)[-10:]) if __name__ == '__main__': absltest.main() diff --git a/tests/multibackend_test.py b/tests/multibackend_test.py index 97e6689ba..87761c1cb 100644 --- a/tests/multibackend_test.py +++ b/tests/multibackend_test.py @@ -51,7 +51,7 @@ class MultiBackendTest(jtu.JaxTestCase): y = npr.uniform(size=(10,10)) z_host = np.matmul(x, y) z = fun(x, y) - self.assertAllClose(z, z_host, check_dtypes=True, rtol=1e-2) + self.assertAllClose(z, z_host, rtol=1e-2) correct_platform = backend if backend else jtu.device_under_test() self.assertEqual(z.device_buffer.platform(), correct_platform) @@ -73,7 +73,7 @@ class MultiBackendTest(jtu.JaxTestCase): y = npr.uniform(size=(10,10)) z_host = np.matmul(x, y) + np.ones_like(x) z = fun(x, y) - self.assertAllClose(z, z_host, check_dtypes=True, rtol=1e-2) + self.assertAllClose(z, z_host, rtol=1e-2) correct_platform = outer if outer else jtu.device_under_test() self.assertEqual(z.device_buffer.platform(), correct_platform) diff --git a/tests/nn_test.py b/tests/nn_test.py index 6b5a33061..d4ca83cf1 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -55,7 +55,7 @@ class NNFunctionsTest(jtu.JaxTestCase): def testSoftplusGradInf(self): self.assertAllClose( - 1., jax.grad(nn.softplus)(float('inf')), check_dtypes=True) + 1., 
jax.grad(nn.softplus)(float('inf'))) def testSoftplusGradNegInf(self): check_grads(nn.softplus, (-float('inf'),), order=1, @@ -91,7 +91,7 @@ class NNFunctionsTest(jtu.JaxTestCase): def testGluValue(self): val = nn.glu(jnp.array([1.0, 0.0])) - self.assertAllClose(val, jnp.array([0.5]), check_dtypes=True) + self.assertAllClose(val, jnp.array([0.5])) @parameterized.parameters(*itertools.product( (jnp.float32, jnp.bfloat16, jnp.float16), @@ -120,33 +120,33 @@ class NNFunctionsTest(jtu.JaxTestCase): expected = jnp.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]) - self.assertAllClose(actual, expected, check_dtypes=True) + self.assertAllClose(actual, expected) actual = nn.one_hot(jnp.array([1, 2, 0]), 3) expected = jnp.array([[0., 1., 0.], [0., 0., 1.], [1., 0., 0.]]) - self.assertAllClose(actual, expected, check_dtypes=True) + self.assertAllClose(actual, expected) def testOneHotOutOfBound(self): actual = nn.one_hot(jnp.array([-1, 3]), 3) expected = jnp.array([[0., 0., 0.], [0., 0., 0.]]) - self.assertAllClose(actual, expected, check_dtypes=True) + self.assertAllClose(actual, expected) def testOneHotNonArrayInput(self): actual = nn.one_hot([0, 1, 2], 3) expected = jnp.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]) - self.assertAllClose(actual, expected, check_dtypes=True) + self.assertAllClose(actual, expected) def testOneHotCustomDtype(self): actual = nn.one_hot(jnp.array([0, 1, 2]), 3, dtype=jnp.bool_) expected = jnp.array([[True, False, False], [False, True, False], [False, False, True]]) - self.assertAllClose(actual, expected, check_dtypes=True) + self.assertAllClose(actual, expected) InitializerRecord = collections.namedtuple( "InitializerRecord", diff --git a/tests/optimizers_test.py b/tests/optimizers_test.py index 3dd07cc44..f025aa3c6 100644 --- a/tests/optimizers_test.py +++ b/tests/optimizers_test.py @@ -39,7 +39,7 @@ class OptimizerTests(jtu.JaxTestCase): def _CheckFuns(self, optimizer, loss, x0, *args): init_fun, update_fun, get_params = optimizer(*args) opt_state = init_fun(x0) - self.assertAllClose(x0, get_params(opt_state), check_dtypes=True) + self.assertAllClose(x0, get_params(opt_state)) opt_state2 = update_fun(0, grad(loss)(x0), opt_state) # doesn't crash self.assertEqual(tree_util.tree_structure(opt_state), tree_util.tree_structure(opt_state2)) @@ -294,7 +294,7 @@ class OptimizerTests(jtu.JaxTestCase): J1 = jacrev(loss, argnums=(0,))(initial_params) J2 = jacfwd(loss, argnums=(0,))(initial_params) - self.assertAllClose(J1, J2, check_dtypes=True, rtol=1e-6) + self.assertAllClose(J1, J2, rtol=1e-6) def testUnpackPackRoundTrip(self): opt_init, _, _ = optimizers.momentum(0.1, mass=0.9) diff --git a/tests/parallel_test.py b/tests/parallel_test.py index 0bce8cd3e..f6d4b6c8d 100644 --- a/tests/parallel_test.py +++ b/tests/parallel_test.py @@ -87,7 +87,7 @@ class PapplyTest(jtu.JaxTestCase): t = np.ones((5, 3)) ans = soft_pmap(*_papply(fun))(t) expected = fun(t) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) def testLogSoftmax(self): raise SkipTest("test doesn't pass yet") # TODO(frostig) @@ -113,7 +113,7 @@ class PapplyTest(jtu.JaxTestCase): pfun, axis_name = _papply(jnp.add) ans = soft_pmap(pfun, axis_name)(x, x) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) def testAddBroadcasting(self): raise SkipTest("test doesn't pass yet") # TODO(frostig) @@ -126,7 +126,7 @@ class PapplyTest(jtu.JaxTestCase): pfun, axis_name = _papply(fun) ans = soft_pmap(pfun, axis_name)(x) - 
self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) def testMakeJaxprPapplyComposition(self): raise SkipTest( # TODO(mattjj) @@ -142,7 +142,7 @@ class ParallelizeTest(jtu.JaxTestCase): def dedup(self, arr, expected_rank): if arr.ndim == expected_rank + 1: for i in range(arr.shape[0] - 1): - self.assertAllClose(arr[i], arr[i + 1], check_dtypes=True) + self.assertAllClose(arr[i], arr[i + 1]) return arr[0] else: assert arr.ndim == expected_rank diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 205e49a39..010b06cb3 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -217,7 +217,7 @@ class PmapTest(jtu.JaxTestCase): f_expected = np.broadcast_to(x, mesh_shape) f_ans = f(x, y) - self.assertAllClose(f_ans, f_expected, check_dtypes=True) + self.assertAllClose(f_ans, f_expected) self.assertIsInstance(f_ans, pxla.ShardedDeviceArray) # the output is actually replicated (has the same values in each device buffer) # but out_axes is implicitly 0, so we shouldn't have replication in the @@ -226,7 +226,7 @@ class PmapTest(jtu.JaxTestCase): g_expected = np.broadcast_to(x - np.sum(y, 0, keepdims=True), shape) g_ans = g(x, y) - self.assertAllClose(g_ans, g_expected, check_dtypes=True) + self.assertAllClose(g_ans, g_expected) self.assertIsInstance(g_ans, pxla.ShardedDeviceArray) self.assertEqual(g_ans.sharding_spec.replication_factor, 1) @@ -302,7 +302,7 @@ class PmapTest(jtu.JaxTestCase): ans = grad(lambda x: jnp.sum(splitjvp(x)))(x) expected = grad(fun)(x) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) def testTwoArgsGrad(self): def f(x, y): @@ -352,7 +352,7 @@ class PmapTest(jtu.JaxTestCase): ans = grad(lambda x: jnp.sum(test_fun(x)))(x) expected = grad(lambda x: jnp.sum(baseline_fun(x)))(x) - self.assertAllClose(ans, expected, check_dtypes=True, atol=1e-3) + self.assertAllClose(ans, expected, atol=1e-3) def testShardedDeviceArrays(self): f = lambda x: 2 * x @@ -466,11 +466,11 @@ class PmapTest(jtu.JaxTestCase): expected_psum = np.concatenate([expected_psum_1, expected_psum_2], 1) expected = x - expected_psum ans = f1(x) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) expected = x - expected_psum + 1. 
ans = f2(x) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) shape = (replicas // 2, 2, 4) x = np.arange(prod(shape), dtype=np.float32).reshape(shape) @@ -482,7 +482,7 @@ class PmapTest(jtu.JaxTestCase): expected_psum = np.concatenate([expected_psum_1, expected_psum_2], 0) expected = x - expected_psum ans = f3(x) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) def testAxisGroups(self): axis_env = xla.AxisEnv(8, ('i', 'j'), (4, 2)) @@ -568,7 +568,7 @@ class PmapTest(jtu.JaxTestCase): lambda x: lax.ppermute(x, "i", zip(range(num_devices), perm)), "i") result = f(jnp.arange(num_devices, dtype=jnp.float32)) expected = jnp.asarray(perm, dtype=jnp.float32) - self.assertAllClose(result, expected, check_dtypes=True) + self.assertAllClose(result, expected) @jtu.skip_on_devices("cpu", "gpu") def testRule30(self): @@ -911,7 +911,7 @@ class PmapTest(jtu.JaxTestCase): arr = pxla.ShardedDeviceArray(aval, sharding_spec, bufs) r = pmap(lambda x: x + 1)(arr) - self.assertAllClose(r, arr + 1, check_dtypes=True) + self.assertAllClose(r, arr + 1) self.assertEqual(len(r.device_buffers), 6) @ignore_soft_pmap_warning() @@ -1114,8 +1114,7 @@ class PmapTest(jtu.JaxTestCase): vals = list(range(500)) ndevices = xla_bridge.device_count() self.assertAllClose(f(jnp.array([vals] * ndevices)), - jnp.array([sum(vals)] * ndevices), - check_dtypes=True) + jnp.array([sum(vals)] * ndevices)) def testPostProcessMap(self): # code from https://github.com/google/jax/issues/2787 @@ -1221,7 +1220,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase): x = np.arange(prod(shape), dtype=np.float32).reshape(shape) expected = x - np.sum(x, 0) ans = f(x) - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) def testOneDevice(self): if xla_bridge.device_count() == 1: @@ -1236,8 +1235,8 @@ class PmapWithDevicesTest(jtu.JaxTestCase): r0 = f0(x) r1 = f1(x) expected = np.expand_dims(np.dot(x.squeeze(), x.squeeze().T), 0) - self.assertAllClose(r0, expected, check_dtypes=True, atol=1e-6, rtol=1e-3) - self.assertAllClose(r1, expected, check_dtypes=True, atol=1e-6, rtol=1e-3) + self.assertAllClose(r0, expected, atol=1e-6, rtol=1e-3) + self.assertAllClose(r1, expected, atol=1e-6, rtol=1e-3) def testNoDevicesError(self): f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i', devices=[]) @@ -1303,7 +1302,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase): ndevices = xla_bridge.device_count() ans = foo(jnp.ones((ndevices, 1))) expected = np.ones((ndevices, 1), dtype=jnp.float_) * ndevices * 2 - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) def testPmapInJit(self): @jit @@ -1316,7 +1315,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase): ndevices = xla_bridge.device_count() ans = foo(jnp.ones((ndevices, 1))) expected = np.ones((ndevices, 1), dtype=jnp.float_) * ndevices - self.assertAllClose(ans, expected, check_dtypes=True) + self.assertAllClose(ans, expected) def testGradBasic(self): @partial(pmap, axis_name='i', devices=xla_bridge.devices()) diff --git a/tests/random_test.py b/tests/random_test.py index 97c344474..a0239c1e4 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -144,15 +144,15 @@ class LaxRandomTest(jtu.JaxTestCase): if jtu.device_under_test() != "tpu": bits8 = random._random_bits(key, 8, (3,)) expected8 = np.array([216, 115, 43], dtype=np.uint8) - self.assertArraysEqual(bits8, expected8, check_dtypes=True) + self.assertArraysEqual(bits8, expected8) bits16 = 
random._random_bits(key, 16, (3,)) expected16 = np.array([41682, 1300, 55017], dtype=np.uint16) - self.assertArraysEqual(bits16, expected16, check_dtypes=True) + self.assertArraysEqual(bits16, expected16) bits32 = random._random_bits(key, 32, (3,)) expected32 = np.array([56197195, 4200222568, 961309823], dtype=np.uint32) - self.assertArraysEqual(bits32, expected32, check_dtypes=True) + self.assertArraysEqual(bits32, expected32) bits64 = random._random_bits(key, 64, (3,)) if FLAGS.jax_enable_x64: @@ -160,7 +160,7 @@ class LaxRandomTest(jtu.JaxTestCase): 7882654074788531506], dtype=np.uint64) else: expected64 = np.array([676898860, 3164047411, 4010691890], dtype=np.uint32) - self.assertArraysEqual(bits64, expected64, check_dtypes=True) + self.assertArraysEqual(bits64, expected64) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype)} @@ -227,7 +227,7 @@ class LaxRandomTest(jtu.JaxTestCase): with self.assertWarns(FutureWarning): perm2 = crand(key) - self.assertAllClose(perm1, perm2, check_dtypes=True) + self.assertAllClose(perm1, perm2) self.assertFalse(np.all(perm1 == x)) # seems unlikely! self.assertAllClose(np.sort(perm1), x, check_dtypes=False) @@ -245,12 +245,11 @@ class LaxRandomTest(jtu.JaxTestCase): perm1 = rand(key) perm2 = crand(key) - self.assertAllClose(perm1, perm2, check_dtypes=True) + self.assertAllClose(perm1, perm2) self.assertFalse(np.all(perm1 == x)) # seems unlikely! self.assertAllClose(np.sort(perm1.ravel()), x.ravel(), check_dtypes=False) self.assertArraysAllClose( - x, jnp.arange(np.prod(shape)).reshape(shape).astype(dtype), - check_dtypes=True) + x, jnp.arange(np.prod(shape)).reshape(shape).astype(dtype)) def testPermutationInteger(self): key = random.PRNGKey(0) @@ -261,7 +260,7 @@ class LaxRandomTest(jtu.JaxTestCase): perm1 = rand(key) perm2 = crand(key) - self.assertAllClose(perm1, perm2, check_dtypes=True) + self.assertAllClose(perm1, perm2) self.assertEqual(perm1.dtype, perm2.dtype) self.assertFalse(np.all(perm1 == np.arange(100))) # seems unlikely! 
self.assertAllClose(np.sort(perm1), np.arange(100), check_dtypes=False) @@ -380,7 +379,7 @@ class LaxRandomTest(jtu.JaxTestCase): compiled_samples = crand(key, alpha) for samples in [uncompiled_samples, compiled_samples]: - self.assertAllClose(samples.sum(-1), np.ones(10000, dtype=dtype), check_dtypes=True) + self.assertAllClose(samples.sum(-1), np.ones(10000, dtype=dtype)) alpha_sum = sum(alpha) for i, a in enumerate(alpha): self._CheckKolmogorovSmirnovCDF(samples[..., i], scipy.stats.beta(a, alpha_sum - a).cdf) @@ -435,7 +434,7 @@ class LaxRandomTest(jtu.JaxTestCase): pdf = scipy.stats.gamma.pdf(z, alpha) expected_grad = -cdf_dot / pdf - self.assertAllClose(actual_grad, expected_grad, check_dtypes=True, + self.assertAllClose(actual_grad, expected_grad, rtol=2e-2 if jtu.device_under_test() == "tpu" else 5e-4) def testGammaGradType(self): @@ -670,28 +669,24 @@ class LaxRandomTest(jtu.JaxTestCase): random.randint(k, (3, 3), 0, 8), np.array([[7, 2, 6], [2, 1, 0], - [6, 7, 7]], dtype='int64'), - check_dtypes=True) + [6, 7, 7]], dtype='int64')) else: self.assertAllClose( random.randint(k, (3, 3), 0, 8), np.array([[2, 1, 3], [6, 1, 5], - [6, 3, 4]], dtype='int32'), - check_dtypes=True) + [6, 3, 4]], dtype='int32')) self.assertAllClose( random.split(k, 4), np.array([[2285895361, 1501764800], [1518642379, 4090693311], [ 433833334, 4221794875], - [ 839183663, 3740430601]], dtype='uint32'), - check_dtypes=True) + [ 839183663, 3740430601]], dtype='uint32')) self.assertAllClose( random.fold_in(k, 4), - np.array([2285895361, 433833334], dtype='uint32'), - check_dtypes=True) + np.array([2285895361, 433833334], dtype='uint32')) if __name__ == "__main__": diff --git a/tests/scipy_ndimage_test.py b/tests/scipy_ndimage_test.py index 81a05fc63..fbfe0f1cd 100644 --- a/tests/scipy_ndimage_test.py +++ b/tests/scipy_ndimage_test.py @@ -103,11 +103,9 @@ class NdimageTest(jtu.JaxTestCase): if dtype in float_dtypes: epsilon = max([dtypes.finfo(dtypes.canonicalize_dtype(d)).eps for d in [dtype, coords_dtype]]) - self._CheckAgainstNumpy(lsp_op, osp_op, args_maker, tol=100*epsilon, - check_dtypes=True) + self._CheckAgainstNumpy(lsp_op, osp_op, args_maker, tol=100*epsilon) else: - self._CheckAgainstNumpy(lsp_op, osp_op, args_maker, tol=0, - check_dtypes=True) + self._CheckAgainstNumpy(lsp_op, osp_op, args_maker, tol=0) def testMapCoordinatesErrors(self): x = onp.arange(5.0) @@ -137,7 +135,7 @@ class NdimageTest(jtu.JaxTestCase): lsp_op = lambda x, c: lsp_ndimage.map_coordinates(x, c, order=order) osp_op = lambda x, c: osp_ndimage.map_coordinates(x, c, order=order) - self._CheckAgainstNumpy(lsp_op, osp_op, args_maker, check_dtypes=True) + self._CheckAgainstNumpy(lsp_op, osp_op, args_maker) def testContinuousGradients(self): # regression test for https://github.com/google/jax/issues/3024 diff --git a/tests/scipy_signal_test.py b/tests/scipy_signal_test.py index 3d78413e1..852a75135 100644 --- a/tests/scipy_signal_test.py +++ b/tests/scipy_signal_test.py @@ -64,7 +64,7 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase): jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST) tol = {onp.float16: 1e-2, onp.float32: 1e-2, onp.float64: 1e-8} self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jsp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "op={}_xshape=[{}]_yshape=[{}]_mode={}".format( @@ -87,7 +87,7 @@ class 
LaxBackedScipySignalTests(jtu.JaxTestCase): jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST) tol = {onp.float16: 1e-2, onp.float32: 1e-2, onp.float64: 1e-14} self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jsp_fun, args_maker) if __name__ == "__main__": diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index 9f13cefcc..93a0249a2 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -61,8 +61,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) - self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True, - rtol={onp.float64: 1e-14}) + self._CompileAndCheck(lax_fun, args_maker, rtol={onp.float64: 1e-14}) @genNamedParametersNArgs(3, jtu.rand_default) def testPoissonPmf(self, rng_factory, shapes, dtypes): @@ -80,7 +79,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) - self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3, jtu.rand_default) def testBernoulliLogPmf(self, rng_factory, shapes, dtypes): @@ -97,7 +96,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) - self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3, jtu.rand_default) def testGeomLogPmf(self, rng_factory, shapes, dtypes): @@ -114,7 +113,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) - self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(5, jtu.rand_positive) def testBetaLogPdf(self, rng_factory, shapes, dtypes): @@ -128,7 +127,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) - self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True, + self._CompileAndCheck(lax_fun, args_maker, rtol={onp.float32: 2e-3, onp.float64: 1e-4}) @genNamedParametersNArgs(3, jtu.rand_default) @@ -145,7 +144,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) - self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(2, jtu.rand_positive) def testDirichletLogPdf(self, rng_factory, shapes, dtypes): @@ -162,7 +161,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) - self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3, jtu.rand_positive) def testExponLogPdf(self, rng_factory, shapes, dtypes): @@ -176,7 +175,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) - self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(4, jtu.rand_positive) def testGammaLogPdf(self, rng_factory, shapes, dtypes): @@ -190,7 
+189,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4) - self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3, jtu.rand_positive) def testLaplaceLogPdf(self, rng_factory, shapes, dtypes): @@ -206,7 +205,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) - self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3, jtu.rand_default) def testLaplaceCdf(self, rng_factory, shapes, dtypes): @@ -222,7 +221,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol={onp.float32: 1e-5, onp.float64: 1e-6}) - self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(1, jtu.rand_default) def testLogisticCdf(self, rng_factory, shapes, dtypes): @@ -235,7 +234,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-6) - self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(1, jtu.rand_default) def testLogisticLogpdf(self, rng_factory, shapes, dtypes): @@ -248,7 +247,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) - self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(1, jtu.rand_default) def testLogisticPpf(self, rng_factory, shapes, dtypes): @@ -261,7 +260,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) - self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(1, jtu.rand_default) def testLogisticSf(self, rng_factory, shapes, dtypes): @@ -274,7 +273,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-6) - self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3, jtu.rand_default) def testNormLogPdf(self, rng_factory, shapes, dtypes): @@ -290,7 +289,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) - self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3, jtu.rand_default) @@ -307,7 +306,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) - self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3, jtu.rand_default) @@ -324,7 +323,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-6) - self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3, jtu.rand_default) @@ -341,9 +340,8 @@ class 
LaxBackedScipyStatsTests(jtu.JaxTestCase): scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None) return [q, loc, scale] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True, - tol=1e-4) - self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True, rtol=3e-4) + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4) + self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4) @genNamedParametersNArgs(4, jtu.rand_positive) @@ -358,7 +356,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) - self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(4, jtu.rand_default) @@ -375,7 +373,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) - self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3, jtu.rand_default) @@ -390,7 +388,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) - self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lax_fun, args_maker) def testIssue972(self): self.assertAllClose( @@ -455,9 +453,9 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CheckAgainstNumpy(osp_stats.multivariate_normal.logpdf, lsp_stats.multivariate_normal.logpdf, - args_maker, check_dtypes=True, tol=1e-3) + args_maker, tol=1e-3) self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf, args_maker, - check_dtypes=True, rtol=1e-4, atol=1e-4) + rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/vectorize_test.py b/tests/vectorize_test.py index f3c05c206..028005785 100644 --- a/tests/vectorize_test.py +++ b/tests/vectorize_test.py @@ -127,18 +127,18 @@ class VectorizeTest(jtu.JaxTestCase): b, a = center(jnp.arange(3)) self.assertEqual(a.shape, (3,)) self.assertEqual(b.shape, ()) - self.assertAllClose(1.0, b, False) + self.assertAllClose(1.0, b, check_dtypes=False) X = jnp.arange(12).reshape((3, 4)) b, a = center(X, axis=1) self.assertEqual(a.shape, (3, 4)) self.assertEqual(b.shape, (3,)) - self.assertAllClose(jnp.mean(X, axis=1), b, True) + self.assertAllClose(jnp.mean(X, axis=1), b) b, a = center(X, axis=0) self.assertEqual(a.shape, (3, 4)) self.assertEqual(b.shape, (4,)) - self.assertAllClose(jnp.mean(X, axis=0), b, True) + self.assertAllClose(jnp.mean(X, axis=0), b) if __name__ == "__main__":