Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_… (#3280)

* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs.
Default to check_dtypes=True.

Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense.

No functional changes intended.

* Fix a number of lax reference implementations to preserve types.
This commit is contained in:
Peter Hawkins 2020-06-01 17:19:23 -04:00 committed by GitHub
parent 49a441f745
commit fffdb2daa8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 559 additions and 615 deletions

View File

@ -32,7 +32,7 @@ neg = np.negative
sign = np.sign
floor = np.floor
ceil = np.ceil
round = lambda x: np.trunc(x + np.copysign(.5, x))
round = lambda x: np.trunc(x + np.copysign(.5, x)).astype(x.dtype)
nextafter = np.nextafter
is_finite = np.isfinite
@ -47,7 +47,7 @@ cos = np.cos
atan2 = np.arctan2
sqrt = np.sqrt
rsqrt = lambda x: 1. / np.sqrt(x)
rsqrt = lambda x: np.ones_like(x) / np.sqrt(x)
square = np.square
reciprocal = np.reciprocal
tan = np.tan
@ -60,16 +60,17 @@ asinh = np.arcsinh
acosh = np.arccosh
atanh = np.arctanh
betainc = scipy.special.betainc
lgamma = scipy.special.gammaln
digamma = scipy.special.digamma
def betainc(a, b, x): return scipy.special.betainc(a, b, x).astype(x.dtype)
def lgamma(x): return scipy.special.gammaln(x).astype(x.dtype)
def digamma(x): return scipy.special.digamma(x).astype(x.dtype)
igamma = scipy.special.gammainc
igammac = scipy.special.gammaincc
erf = scipy.special.erf
erfc = scipy.special.erfc
erf_inv = scipy.special.erfinv
bessel_i0e = scipy.special.i0e
bessel_i1e = scipy.special.i1e
def erf(x): return scipy.special.erf(x).astype(x.dtype)
def erfc(x): return scipy.special.erfc(x).astype(x.dtype)
def erf_inv(x): return scipy.special.erfinv(x).astype(x.dtype)
def bessel_i0e(x): return scipy.special.i0e(x).astype(x.dtype)
def bessel_i1e(x): return scipy.special.i1e(x).astype(x.dtype)
real = np.real
imag = np.imag
@ -150,7 +151,7 @@ def bitcast_convert_type(operand, dtype):
return np.asarray(operand).view(dtype)
def clamp(min, operand, max):
return np.clip(operand, np.clip(min, None, max), max)
return np.clip(operand, np.clip(min, None, max), max).astype(operand.dtype)
def concatenate(operands, dimension):
return np.concatenate(operands, axis=dimension)
@ -295,8 +296,6 @@ def sort_key_val(keys, values, dimension=-1):
idxs[dimension] = np.argsort(keys, axis=dimension)
return keys[tuple(idxs)], values[tuple(idxs)]
# TODO untake
### conv util
def _conv(lhs, rhs, window_strides, pads):

View File

@ -719,13 +719,14 @@ class JaxTestCase(parameterized.TestCase):
def rng(self):
return self._rng
def assertArraysEqual(self, x, y, check_dtypes):
def assertArraysEqual(self, x, y, *, check_dtypes=True):
"""Assert that x and y arrays are exactly equal."""
if check_dtypes:
self.assertDtypesMatch(x, y)
np.testing.assert_equal(x, y)
def assertArraysAllClose(self, x, y, check_dtypes, atol=None, rtol=None):
def assertArraysAllClose(self, x, y, *, check_dtypes=True, atol=None,
rtol=None):
"""Assert that x and y are close (up to numerical tolerances)."""
self.assertEqual(x.shape, y.shape)
atol = max(tolerance(_dtype(x), atol), tolerance(_dtype(y), atol))
@ -740,18 +741,20 @@ class JaxTestCase(parameterized.TestCase):
if FLAGS.jax_enable_x64:
self.assertEqual(_dtype(x), _dtype(y))
def assertAllClose(self, x, y, check_dtypes, atol=None, rtol=None):
def assertAllClose(self, x, y, *, check_dtypes=True, atol=None, rtol=None):
"""Assert that x and y, either arrays or nested tuples/lists, are close."""
if isinstance(x, dict):
self.assertIsInstance(y, dict)
self.assertEqual(set(x.keys()), set(y.keys()))
for k in x.keys():
self.assertAllClose(x[k], y[k], check_dtypes, atol=atol, rtol=rtol)
self.assertAllClose(x[k], y[k], check_dtypes=check_dtypes, atol=atol,
rtol=rtol)
elif is_sequence(x) and not hasattr(x, '__array__'):
self.assertTrue(is_sequence(y) and not hasattr(y, '__array__'))
self.assertEqual(len(x), len(y))
for x_elt, y_elt in zip(x, y):
self.assertAllClose(x_elt, y_elt, check_dtypes, atol=atol, rtol=rtol)
self.assertAllClose(x_elt, y_elt, check_dtypes=check_dtypes, atol=atol,
rtol=rtol)
elif hasattr(x, '__array__') or np.isscalar(x):
self.assertTrue(hasattr(y, '__array__') or np.isscalar(y))
if check_dtypes:
@ -772,7 +775,7 @@ class JaxTestCase(parameterized.TestCase):
self.assertMultiLineEqual(expected_clean, what_clean,
msg="Found\n{}\nExpecting\n{}".format(what, expected))
def _CompileAndCheck(self, fun, args_maker, check_dtypes,
def _CompileAndCheck(self, fun, args_maker, *, check_dtypes=True,
rtol=None, atol=None):
"""Helper method for running JAX compilation and allclose assertions."""
args = args_maker()
@ -802,8 +805,10 @@ class JaxTestCase(parameterized.TestCase):
python_should_be_executing = False
compiled_ans = cfun(*args)
self.assertAllClose(python_ans, monitored_ans, check_dtypes, atol, rtol)
self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol)
self.assertAllClose(python_ans, monitored_ans, check_dtypes=check_dtypes,
atol=atol, rtol=rtol)
self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes,
atol=atol, rtol=rtol)
args = args_maker()
@ -813,10 +818,11 @@ class JaxTestCase(parameterized.TestCase):
python_should_be_executing = False
compiled_ans = cfun(*args)
self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol)
self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes,
atol=atol, rtol=rtol)
def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker,
check_dtypes=False, tol=None):
check_dtypes=True, tol=None):
args = args_maker()
lax_ans = lax_op(*args)
numpy_ans = numpy_reference_op(*args)

View File

@ -462,7 +462,7 @@ class APITest(jtu.JaxTestCase):
def test_grad_and_aux_basic(self):
g, aux = grad(lambda x: (x**3, [x**2]), has_aux=True)(3.)
self.assertAllClose(g, grad(lambda x: x**3)(3.), check_dtypes=True)
self.assertAllClose(g, grad(lambda x: x**3)(3.))
self.assertAllClose(aux, [9.], check_dtypes=False)
def test_grad_and_aux_nested(self):
@ -557,7 +557,7 @@ class APITest(jtu.JaxTestCase):
res2 = api.jit(inner)(5.)
return res1 + res2
self.assertAllClose((45., 9.), api.jvp(func, (5.,), (1.,)), check_dtypes=True)
self.assertAllClose((45., 9.), api.jvp(func, (5.,), (1.,)))
def test_complex_grad_raises_error(self):
@ -596,7 +596,7 @@ class APITest(jtu.JaxTestCase):
ans = jacrev(f)(zs)
expected = grad(f)(zs)
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
def test_complex_input_jacfwd_raises_error(self):
self.assertRaises(TypeError, lambda: jacfwd(lambda x: jnp.sin(x))(1 + 2j))
@ -980,7 +980,7 @@ class APITest(jtu.JaxTestCase):
futures = [executor.submit(partial(f, x)) for x in xs]
ys = [f.result() for f in futures]
for x, y in zip(xs, ys):
self.assertAllClose(x, y, check_dtypes=True)
self.assertAllClose(x, y)
def test_concurrent_jit(self):
@jit
@ -992,7 +992,7 @@ class APITest(jtu.JaxTestCase):
futures = [executor.submit(partial(f, x)) for x in xs]
ys = [f.result() for f in futures]
for x, y in zip(xs, ys):
self.assertAllClose(x * 2 - 3., y, check_dtypes=True)
self.assertAllClose(x * 2 - 3., y)
def test_dtype_warning(self):
# cf. issue #1230
@ -1054,7 +1054,7 @@ class APITest(jtu.JaxTestCase):
out1 = api.vmap(f, (None, 0, 0))(dictionary, x, y)
out2 = api.vmap(f, [None, 0, 0])(dictionary, x, y)
self.assertAllClose(out1, out2, check_dtypes=True)
self.assertAllClose(out1, out2)
def test_vmap_in_axes_tree_prefix_error(self):
# https://github.com/google/jax/issues/795
@ -2007,11 +2007,10 @@ class CustomJVPTest(jtu.JaxTestCase):
f.defjvp(f_jvp)
x = 3.
self.assertAllClose(f(x), jnp.sin(x), check_dtypes=True)
self.assertAllClose(f(x), jnp.sin(x))
self.assertAllClose(api.jvp(f, (x,), (1.,)),
(jnp.sin(x), 2 * jnp.cos(x)),
check_dtypes=True)
self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x), check_dtypes=True)
(jnp.sin(x), 2 * jnp.cos(x)))
self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x))
def test_invariance(self):
@api.custom_jvp
@ -2052,8 +2051,8 @@ class CustomJVPTest(jtu.JaxTestCase):
return f(x), 3 * g
f.defjvp(f_jvp)
x = 2.
self.assertAllClose(f(x), jnp.sin(x), check_dtypes=True)
self.assertAllClose(f(-x), jnp.cos(-x), check_dtypes=True)
self.assertAllClose(f(x), jnp.sin(x))
self.assertAllClose(f(-x), jnp.cos(-x))
self.assertAllClose(api.jvp(f, (x,), (1.,)),
(jnp.sin(x), 2.),
check_dtypes=False)
@ -2079,29 +2078,24 @@ class CustomJVPTest(jtu.JaxTestCase):
xx = jnp.arange(6.).reshape(2, 3)
# vmap of f
self.assertAllClose(api.vmap(f)(x), jnp.sin(x), check_dtypes=True)
self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx), check_dtypes=True)
self.assertAllClose(api.vmap(f)(x), jnp.sin(x))
self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx))
# vmap of jvp of f
self.assertAllClose(api.vmap(lambda x: api.jvp(f, (x,), (x,)))(x),
(jnp.sin(x), 2 * jnp.cos(x) * x),
check_dtypes=True)
(jnp.sin(x), 2 * jnp.cos(x) * x))
self.assertAllClose(api.vmap(api.vmap(lambda x: api.jvp(f, (x,), (x,))))(xx),
(jnp.sin(xx), 2 * jnp.cos(xx) * xx),
check_dtypes=True)
(jnp.sin(xx), 2 * jnp.cos(xx) * xx))
# jvp of vmap of f
self.assertAllClose(api.jvp(api.vmap(f), (x,), (x,)),
(jnp.sin(x), 2 * jnp.cos(x) * x),
check_dtypes=True)
(jnp.sin(x), 2 * jnp.cos(x) * x))
self.assertAllClose(api.jvp(api.vmap(api.vmap(f)), (xx,), (xx,)),
(jnp.sin(xx), 2 * jnp.cos(xx) * xx),
check_dtypes=True)
(jnp.sin(xx), 2 * jnp.cos(xx) * xx))
# vmap of jvp of vmap of f
self.assertAllClose(api.vmap(lambda x: api.jvp(api.vmap(f), (x,), (x,)))(xx),
(jnp.sin(xx), 2 * jnp.cos(xx) * xx),
check_dtypes=True)
(jnp.sin(xx), 2 * jnp.cos(xx) * xx))
def test_jit(self):
@api.custom_jvp
@ -2116,8 +2110,8 @@ class CustomJVPTest(jtu.JaxTestCase):
x = 3.
# jit
self.assertAllClose(api.jit(f)(x), jnp.sin(x), check_dtypes=True)
self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x), check_dtypes=True)
self.assertAllClose(api.jit(f)(x), jnp.sin(x))
self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x))
# jit of jvp
self.assertAllClose(api.jit(lambda x: api.jvp(f, (x,), (x,)))(x),
@ -2139,7 +2133,7 @@ class CustomJVPTest(jtu.JaxTestCase):
return f(x), {'b': 2 * jnp.cos(x['a']) * g['a']}
f.defjvp(f_jvp)
x = {'a': 3.}
self.assertAllClose(f(x)['b'], jnp.sin(x['a']), check_dtypes=True)
self.assertAllClose(f(x)['b'], jnp.sin(x['a']))
self.assertAllClose(api.jvp(f, (x,), (x,)),
({'b': jnp.sin(x['a'])},
{'b': 2 * jnp.cos(x['a']) * x['a']}),
@ -2491,11 +2485,10 @@ class CustomVJPTest(jtu.JaxTestCase):
f.defvjp(f_fwd, f_rev)
x = 3.
self.assertAllClose(f(x), jnp.sin(x), check_dtypes=True)
self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x), check_dtypes=True)
self.assertAllClose(f(x), jnp.sin(x))
self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x))
self.assertAllClose(api.value_and_grad(f)(x),
(jnp.sin(x), 2 * jnp.cos(x)),
check_dtypes=True)
(jnp.sin(x), 2 * jnp.cos(x)))
def test_invariance(self):
@api.custom_vjp
@ -2539,8 +2532,8 @@ class CustomVJPTest(jtu.JaxTestCase):
return (3 * g,)
f.defvjp(f_fwd, f_rev)
x = 2.
self.assertAllClose(f(x), jnp.sin(x), check_dtypes=True)
self.assertAllClose(f(-x), jnp.cos(-x), check_dtypes=True)
self.assertAllClose(f(x), jnp.sin(x))
self.assertAllClose(f(-x), jnp.cos(-x))
self.assertAllClose(api.value_and_grad(f)(x), (jnp.sin(x), 2.),
check_dtypes=False)
self.assertAllClose(api.value_and_grad(f)(-x), (jnp.cos(-x), 3.),
@ -2562,33 +2555,26 @@ class CustomVJPTest(jtu.JaxTestCase):
xx = jnp.arange(6.).reshape(2, 3)
# vmap of f
self.assertAllClose(api.vmap(f)(x), jnp.sin(x), check_dtypes=True)
self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx), check_dtypes=True)
self.assertAllClose(api.vmap(f)(x), jnp.sin(x))
self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx))
# vmap of grad of f
self.assertAllClose(api.vmap(api.grad(f))(x), 2 * jnp.cos(x),
check_dtypes=True)
self.assertAllClose(api.vmap(api.grad(f))(x), 2 * jnp.cos(x))
self.assertAllClose(api.vmap(api.value_and_grad(f))(x),
(jnp.sin(x), 2 * jnp.cos(x)),
check_dtypes=True)
self.assertAllClose(api.vmap(api.vmap(api.grad(f)))(xx), 2 * jnp.cos(xx),
check_dtypes=True)
(jnp.sin(x), 2 * jnp.cos(x)))
self.assertAllClose(api.vmap(api.vmap(api.grad(f)))(xx), 2 * jnp.cos(xx))
self.assertAllClose(api.vmap(api.vmap(api.value_and_grad(f)))(xx),
(jnp.sin(xx), 2 * jnp.cos(xx)),
check_dtypes=True)
(jnp.sin(xx), 2 * jnp.cos(xx)))
# grad of vmap of f
self.assertAllClose(api.grad(lambda x: api.vmap(f)(x).sum())(x),
2 * jnp.cos(x),
check_dtypes=True)
2 * jnp.cos(x))
self.assertAllClose(api.grad(lambda x: api.vmap(api.vmap(f))(x).sum())(xx),
2 * jnp.cos(xx),
check_dtypes=True)
2 * jnp.cos(xx))
# vmap of grad of vmap of f
self.assertAllClose(api.vmap(api.grad(lambda x: api.vmap(f)(x).sum()))(xx),
2 * jnp.cos(xx),
check_dtypes=True)
2 * jnp.cos(xx))
def test_jit(self):
@api.custom_vjp
@ -2603,8 +2589,8 @@ class CustomVJPTest(jtu.JaxTestCase):
x = 3.
# jit
self.assertAllClose(api.jit(f)(x), jnp.sin(x), check_dtypes=True)
self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x), check_dtypes=True)
self.assertAllClose(api.jit(f)(x), jnp.sin(x))
self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x))
# jit of grad
self.assertAllClose(api.jit(api.grad(f))(x), 2 * jnp.cos(x),
@ -2625,10 +2611,9 @@ class CustomVJPTest(jtu.JaxTestCase):
return ({'a': 2 * cos_x * g['b']},)
f.defvjp(f_fwd, f_bwd)
x = {'a': 3.}
self.assertAllClose(f(x)['b'], jnp.sin(x['a']), check_dtypes=True)
self.assertAllClose(f(x)['b'], jnp.sin(x['a']))
self.assertAllClose(api.grad(lambda x: f(x)['b'])(x),
{'a': 2 * jnp.cos(x['a'])},
check_dtypes=True)
{'a': 2 * jnp.cos(x['a'])})
def test_jvp_error(self):
@api.custom_vjp
@ -2680,7 +2665,7 @@ class CustomVJPTest(jtu.JaxTestCase):
ans = api.grad(api.grad(foo))(3.)
expected = -2. * jnp.sin(3.)
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
def test_initial_style_vmap(self):
@api.custom_vjp
@ -2851,7 +2836,7 @@ class DeprecatedCustomTransformsTest(jtu.JaxTestCase):
ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (12.,)))
val_ans, grad_ans = api.value_and_grad(foo)(3.)
self.assertAllClose(val_ans, 9., check_dtypes=False)
self.assertAllClose(grad_ans, 12., check_dtypes=True)
self.assertAllClose(grad_ans, 12.)
def test_defvjp_all_higher_order_revmode(self):
foo_p = Primitive('foo')

View File

@ -69,7 +69,7 @@ class DLPackTest(jtu.JaxTestCase):
x = jnp.array(np)
dlpack = jax.dlpack.to_dlpack(x)
y = jax.dlpack.from_dlpack(dlpack)
self.assertAllClose(np.astype(x.dtype), y, check_dtypes=True)
self.assertAllClose(np.astype(x.dtype), y)
self.assertRaisesRegex(RuntimeError,
"DLPack tensor may be consumed at most once",
@ -89,7 +89,7 @@ class DLPackTest(jtu.JaxTestCase):
x = x.cuda() if jtu.device_under_test() == "gpu" else x
dlpack = torch.utils.dlpack.to_dlpack(x)
y = jax.dlpack.from_dlpack(dlpack)
self.assertAllClose(np, y, check_dtypes=True)
self.assertAllClose(np, y)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(
@ -104,7 +104,7 @@ class DLPackTest(jtu.JaxTestCase):
x = jnp.array(np)
dlpack = jax.dlpack.to_dlpack(x)
y = torch.utils.dlpack.from_dlpack(dlpack)
self.assertAllClose(np, y.numpy(), check_dtypes=True)
self.assertAllClose(np, y.numpy())
class CudaArrayInterfaceTest(jtu.JaxTestCase):
@ -128,7 +128,7 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):
z = cupy.asarray(y)
self.assertEqual(y.__cuda_array_interface__["data"][0],
z.__cuda_array_interface__["data"][0])
self.assertAllClose(x, cupy.asnumpy(z), check_dtypes=True)
self.assertAllClose(x, cupy.asnumpy(z))
if __name__ == "__main__":

View File

@ -195,7 +195,7 @@ class BatchingTest(jtu.JaxTestCase):
ans = vmap(lambda x: x > 1.0)(x)
expected_ans = x > 1.0
self.assertAllClose(ans, expected_ans, check_dtypes=True)
self.assertAllClose(ans, expected_ans)
def testNpMaximumPerExampleGrad(self):
R = np.random.RandomState(0).randn
@ -226,35 +226,35 @@ class BatchingTest(jtu.JaxTestCase):
fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
ans = vmap(fun)(x, y)
expected = lax.dot_general(x, y, [((3,), (2,)), ((0, 1), (0, 1))])
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
x = R(3, 4, 10, 5)
y = R(3, 10, 5, 6)
fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
ans = vmap(fun, in_axes=(2, 1))(x, y)
expected = np.stack([fun(x[..., i, :], y[:, i, ...]) for i in range(10)])
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
x = R(3, 4, 5, 10)
y = R(3, 5, 6)
fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
ans = vmap(fun, in_axes=(3, None))(x, y)
expected = np.stack([fun(x[..., i], y) for i in range(10)])
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
x = R(3, 4, 5)
y = R(3, 5, 10, 6)
fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
ans = vmap(fun, in_axes=(None, 2))(x, y)
expected = np.stack([fun(x, y[..., i, :]) for i in range(10)])
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
x = R(4)
y = R(4, 10)
fun = lambda x, y: lax.dot_general(x, y, [((0,), (0,)), ((), ())])
ans = vmap(fun, in_axes=(None, 1))(x, y)
expected = np.stack([fun(x, y[..., i]) for i in range(10)])
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
def testDot(self):
# these tests are based on @shoyer's notebook studying gufuncs
@ -348,7 +348,7 @@ class BatchingTest(jtu.JaxTestCase):
ans = vmap(jnp.any)(jnp.array([[True, False], [False, False]]))
expected = jnp.array([True, False])
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
@jtu.skip_on_devices("tpu")
def testHessian(self):
@ -417,46 +417,44 @@ class BatchingTest(jtu.JaxTestCase):
v = np.arange(12)[::-1].reshape(3, 4)
sv = vmap(partial(lax.sort, dimension=0), (0,))(v)
self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)
self.assertAllClose(sv, v[:, ::-1])
sv = vmap(partial(lax.sort, dimension=-1), (0,))(v)
self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)
self.assertAllClose(sv, v[:, ::-1])
sv = vmap(partial(lax.sort, dimension=0), (1,))(v)
self.assertAllClose(sv, v[::-1, :].T, check_dtypes=True)
self.assertAllClose(sv, v[::-1, :].T)
sv = vmap(partial(lax.sort, dimension=0), (1,), 1)(v)
self.assertAllClose(sv, v[::-1, :], check_dtypes=True)
self.assertAllClose(sv, v[::-1, :])
def testSortKeyVal(self):
k = np.arange(12)[::-1].reshape(3, 4)
v = np.random.RandomState(0).permutation(12).reshape(3, 4)
sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 0))(k, v)
self.assertAllClose(sk, k[:, ::-1], check_dtypes=True)
self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)
self.assertAllClose(sk, k[:, ::-1])
self.assertAllClose(sv, v[:, ::-1])
sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 1), 1)(k, v)
self.assertAllClose(sk, k[::-1, :], check_dtypes=True)
self.assertAllClose(sv, v[::-1, :], check_dtypes=True)
self.assertAllClose(sk, k[::-1, :])
self.assertAllClose(sv, v[::-1, :])
sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 1))(k, v.T)
self.assertAllClose(sk, k[:, ::-1], check_dtypes=True)
self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)
self.assertAllClose(sk, k[:, ::-1])
self.assertAllClose(sv, v[:, ::-1])
sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 0))(k.T, v)
self.assertAllClose(sk, k[:, ::-1], check_dtypes=True)
self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)
self.assertAllClose(sk, k[:, ::-1])
self.assertAllClose(sv, v[:, ::-1])
sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (None, 0))(k[0], v)
self.assertAllClose(sk, np.broadcast_to(k[0, ::-1], (3, 4)),
check_dtypes=True)
self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)
self.assertAllClose(sk, np.broadcast_to(k[0, ::-1], (3, 4)))
self.assertAllClose(sv, v[:, ::-1])
sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, None))(k.T, v[0])
self.assertAllClose(sk, k[:, ::-1], check_dtypes=True)
self.assertAllClose(sv, np.broadcast_to(v[0, ::-1], (3, 4)),
check_dtypes=True)
self.assertAllClose(sk, k[:, ::-1])
self.assertAllClose(sv, np.broadcast_to(v[0, ::-1], (3, 4)))
def testConvGeneralDilated(self):
W = jnp.array(np.random.randn(3, 3, 1, 5), dtype=np.float32)
@ -474,7 +472,7 @@ class BatchingTest(jtu.JaxTestCase):
per_example = vmap(partial(f, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
per_example = jnp.reshape(per_example, (10, 5, 5, 5))
per_example_direct = f(W, X)
self.assertAllClose(per_example, per_example_direct, check_dtypes=True)
self.assertAllClose(per_example, per_example_direct)
# Test gradients.
per_example = vmap(partial(grad_loss, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
@ -484,7 +482,7 @@ class BatchingTest(jtu.JaxTestCase):
per_example_direct += [
jnp.reshape(g, (1,) + g.shape)]
per_example_direct = jnp.concatenate(per_example_direct, axis=0)
self.assertAllClose(per_example, per_example_direct, check_dtypes=True,
self.assertAllClose(per_example, per_example_direct,
rtol=2e-2)
def testConvGeneralDilatedBatchNotMajor(self):
@ -503,7 +501,7 @@ class BatchingTest(jtu.JaxTestCase):
(5, 5, 21, 4))
per_example_direct = f(W, jnp.reshape(jnp.transpose(x, (1, 0, 2, 3, 4)),
(5, 21, 5, 1)))
self.assertAllClose(per_example, per_example_direct, check_dtypes=True)
self.assertAllClose(per_example, per_example_direct)
@parameterized.named_parameters(
{"testcase_name": "_op={}".format(name), "op": op, "unit": unit}
@ -526,7 +524,7 @@ class BatchingTest(jtu.JaxTestCase):
per_example = vmap(partial(f, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
per_example = jnp.reshape(per_example, (10, 5, 5, 5))
per_example_direct = f(W, X)
self.assertAllClose(per_example, per_example_direct, check_dtypes=True)
self.assertAllClose(per_example, per_example_direct)
# Test gradients.
per_example = vmap(partial(grad_loss, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
@ -536,7 +534,7 @@ class BatchingTest(jtu.JaxTestCase):
per_example_direct += [
jnp.reshape(g, (1,) + g.shape)]
per_example_direct = jnp.concatenate(per_example_direct, axis=0)
self.assertAllClose(per_example, per_example_direct, check_dtypes=True,
self.assertAllClose(per_example, per_example_direct,
rtol=5e-2)
def testSumPool(self):
@ -557,7 +555,7 @@ class BatchingTest(jtu.JaxTestCase):
per_example = vmap(partial(f, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
per_example = jnp.reshape(per_example, (10, 5, 5, 5))
per_example_direct = f(W, X)
self.assertAllClose(per_example, per_example_direct, check_dtypes=True)
self.assertAllClose(per_example, per_example_direct)
# Test gradients.
per_example = vmap(partial(grad_loss, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
@ -567,14 +565,13 @@ class BatchingTest(jtu.JaxTestCase):
per_example_direct += [
jnp.reshape(g, (1,) + g.shape)]
per_example_direct = jnp.concatenate(per_example_direct, axis=0)
self.assertAllClose(per_example, per_example_direct, check_dtypes=True,
self.assertAllClose(per_example, per_example_direct,
rtol=3e-2)
def testCumProd(self):
x = jnp.arange(9).reshape(3, 3) + 1
y = vmap(lambda x: jnp.cumprod(x, axis=-1))(x)
self.assertAllClose(np.cumprod(x, axis=1, dtype=jnp.int_), y,
check_dtypes=True)
self.assertAllClose(np.cumprod(x, axis=1, dtype=jnp.int_), y)
def testSelect(self):
pred = np.array([True, False])
@ -582,7 +579,7 @@ class BatchingTest(jtu.JaxTestCase):
on_false = np.array([2, 3])
ans = vmap(lax.select)(pred, on_true, on_false)
expected = np.array([0, 3])
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
pred = np.array([False, True])
on_true = np.array([0, 1])
@ -590,28 +587,28 @@ class BatchingTest(jtu.JaxTestCase):
ans = vmap(lax.select, (0, None, None))(pred, on_true, on_false)
expected = np.array([[2, 3],
[0, 1]])
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
pred = True
on_true = np.array([0, 1], np.float32)
on_false = np.array(3, np.float32)
ans = vmap(lax.select, (None, 0, None))(pred, on_true, on_false)
expected = np.array([0, 1], np.float32)
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
pred = np.array([False, True])
on_true = np.array([0, 1], np.float32)
on_false = np.array(3, np.float32)
ans = vmap(lax.select, (0, 0, None))(pred, on_true, on_false)
expected = np.array([3, 1], np.float32)
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
pred = np.array([False, True])
on_true = np.array([2], np.float32)
on_false = np.array([[3, 4]], np.float32)
ans = vmap(lax.select, (0, None, 1), 1)(pred, on_true, on_false)
expected = np.array([[3, 2]], np.float32)
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
def testLaxLinalgCholesky(self):
a = np.random.RandomState(0).randn(10, 5, 5).astype(np.float32)
@ -637,17 +634,17 @@ class BatchingTest(jtu.JaxTestCase):
ans = vmap(lax_linalg.triangular_solve, in_axes=(1, 2))(a, b)
expected = np.stack(
[lax_linalg.triangular_solve(a[:, i], b[..., i]) for i in range(10)])
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
ans = vmap(lax_linalg.triangular_solve, in_axes=(None, 2))(a[:, 0], b)
expected = np.stack(
[lax_linalg.triangular_solve(a[:, 0], b[..., i]) for i in range(10)])
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
ans = vmap(lax_linalg.triangular_solve, in_axes=(1, None))(a, b[..., 0])
expected = np.stack(
[lax_linalg.triangular_solve(a[:, i], b[..., 0]) for i in range(10)])
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
@parameterized.named_parameters(
{"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
@ -987,7 +984,7 @@ class BatchingTest(jtu.JaxTestCase):
g = jax.jit(jax.pmap(f))
ans = g(index1=np.asarray([1]), index2=np.asarray([2]))
expected = g(np.asarray([1]), np.asarray([2]))
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
if __name__ == '__main__':

View File

@ -69,11 +69,11 @@ class CallbackTest(jtu.JaxTestCase):
return x * 2
x = jnp.array([2.0, 4.0])
self.assertAllClose(f(x), jnp.array([4.0, 8.0]), True)
self.assertAllClose(f(x), jnp.array([4.0, 8.0]))
self.assertAllClose(
rewrite(f, {lax.mul_p: lambda x, y: x + y})(x),
jnp.array([4.0, 6.0]), True)
jnp.array([4.0, 6.0]))
def testRewriteJIT(self):
def f(x):
@ -83,11 +83,11 @@ class CallbackTest(jtu.JaxTestCase):
return g(x)
x = jnp.array([2.0, 4.0])
self.assertAllClose(f(x), jnp.array([4.0, 8.0]), True)
self.assertAllClose(f(x), jnp.array([4.0, 8.0]))
self.assertAllClose(
rewrite(f, {lax.mul_p: lambda x, y: x + y})(x),
jnp.array([4.0, 6.0]), True)
jnp.array([4.0, 6.0]))
if __name__ == "__main__":
absltest.main()

View File

@ -109,7 +109,7 @@ class FftTest(jtu.JaxTestCase):
# Numpy promotes to complex128 aggressively.
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fn, args_maker)
# Test gradient for differentiable types.
if dtype in (float_dtypes if real and not inverse else inexact_dtypes):
# TODO(skye): can we be more precise?
@ -169,7 +169,7 @@ class FftTest(jtu.JaxTestCase):
# Numpy promotes to complex128 aggressively.
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inverse={}_real={}".format(inverse, real),
@ -228,7 +228,7 @@ class FftTest(jtu.JaxTestCase):
# Numpy promotes to complex128 aggressively.
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inverse={}_real={}".format(inverse, real),
@ -279,7 +279,7 @@ class FftTest(jtu.JaxTestCase):
# Numpy promotes to complex128 aggressively.
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fn, args_maker)
# Test gradient for differentiable types.
if dtype in inexact_dtypes:
tol = 0.15 # TODO(skye): can we be more precise?
@ -323,7 +323,7 @@ class FftTest(jtu.JaxTestCase):
# Numpy promotes to complex128 aggressively.
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fn, args_maker)
# Test gradient for differentiable types.
if dtype in inexact_dtypes:
tol = 0.15 # TODO(skye): can we be more precise?
@ -362,7 +362,7 @@ class FftTest(jtu.JaxTestCase):
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda arg: jnp.fft.fftshift(arg, axes=axes)
np_fn = lambda arg: np.fft.fftshift(arg, axes=axes)
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "dtype={}_axes={}".format(
@ -377,7 +377,7 @@ class FftTest(jtu.JaxTestCase):
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda arg: jnp.fft.ifftshift(arg, axes=axes)
np_fn = lambda arg: np.fft.ifftshift(arg, axes=axes)
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker)
if __name__ == "__main__":
absltest.main()

View File

@ -148,7 +148,7 @@ class HostCallbackTest(jtu.JaxTestCase):
self.assertEqual("", testing_stream.output)
with hcb.outfeed_receiver():
self.assertAllClose((5. * 2.) ** 2, fun1(5.), check_dtypes=True)
self.assertAllClose((5. * 2.) ** 2, fun1(5.))
assertMultiLineStrippedEqual(self, """
what: a * 2
10.00
@ -232,7 +232,7 @@ what: x3
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
res = jit_fun1(5.)
self.assertAllClose(6. * 5., res, check_dtypes=True)
self.assertAllClose(6. * 5., res)
assertMultiLineStrippedEqual(self, """
what: here
10.00""", testing_stream.output)
@ -258,7 +258,7 @@ what: here
self.assertEqual("", testing_stream.output)
with hcb.outfeed_receiver():
self.assertAllClose(5, api.jit(func)(5), check_dtypes=True)
self.assertAllClose(5, api.jit(func)(5))
assertMultiLineStrippedEqual(self, """
42""", testing_stream.output)
testing_stream.reset()
@ -512,7 +512,7 @@ where: 3
if with_jit:
func = api.jit(func)
res = func(1)
self.assertAllClose(jnp.array([1, 2, 3]), res, check_dtypes=True)
self.assertAllClose(jnp.array([1, 2, 3]), res)
assertMultiLineStrippedEqual(self, """
where: 1
1
@ -564,7 +564,7 @@ where: 10
testcase_name=f"shape_{shape}_dtype_{dtype}_nr_args={nr_args}"))
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
res = jit_fun1(args)
# self.assertAllClose(args, res, check_dtypes=True)
# self.assertAllClose(args, res)
def test_jit_large(self):
arg = jnp.arange(10000, dtype=jnp.int32).reshape((10, 10, 5, -1))

View File

@ -44,7 +44,7 @@ class InfeedTest(jax.test_util.JaxTestCase):
device = jax.local_devices()[0]
device.transfer_to_infeed((y,))
device.transfer_to_infeed((z,))
self.assertAllClose(f(x), x + y + z, check_dtypes=True)
self.assertAllClose(f(x), x + y + z)
def testInfeedThenOutfeed(self):
@jax.jit
@ -64,7 +64,7 @@ class InfeedTest(jax.test_util.JaxTestCase):
out, = device.transfer_from_outfeed(
xla_client.shape_from_pyval((y,)).with_major_to_minor_layout_if_absent())
execution.join()
self.assertAllClose(out, y + onp.float32(1), check_dtypes=True)
self.assertAllClose(out, y + onp.float32(1))
def testInfeedThenOutfeedInALoop(self):
def doubler(_, token):
@ -87,7 +87,7 @@ class InfeedTest(jax.test_util.JaxTestCase):
device.transfer_to_infeed((x,))
y, = device.transfer_from_outfeed(xla_client.shape_from_pyval((x,))
.with_major_to_minor_layout_if_absent())
self.assertAllClose(y, x * onp.float32(2), check_dtypes=True)
self.assertAllClose(y, x * onp.float32(2))
execution.join()

View File

@ -159,11 +159,9 @@ class JetTest(jtu.JaxTestCase):
atol = 1e-4
rtol = 1e-4
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol,
check_dtypes=True)
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol)
self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol,
check_dtypes=True)
self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol)
@jtu.skip_on_devices("tpu")

View File

@ -404,7 +404,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertEqual(count(3), 3)
self.assertEqual(count(4), 6)
for args_maker in [lambda: [2], lambda: [3], lambda: [4]]:
self._CompileAndCheck(count, args_maker, True)
self._CompileAndCheck(count, args_maker)
def testForiLoopClosure(self):
def count(num):
@ -1356,13 +1356,13 @@ class LaxControlFlowTest(jtu.JaxTestCase):
xs = jnp.arange(10)
expected = xs ** 2
actual = lax.map(f, xs)
self.assertAllClose(actual, expected, check_dtypes=True)
self.assertAllClose(actual, expected)
def testMapEmpty(self):
# https://github.com/google/jax/issues/2412
ans = lax.map(lambda x: x * x, jnp.array([]))
expected = jnp.array([])
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
def testCaching(self):
def cond(x):
@ -1598,7 +1598,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
actual = api.jit(linear_solve)(a, b)
expected = jnp.linalg.solve(a, b)
self.assertAllClose(expected, actual, check_dtypes=True)
self.assertAllClose(expected, actual)
def test_custom_root_with_custom_linear_solve(self):
@ -1620,11 +1620,11 @@ class LaxControlFlowTest(jtu.JaxTestCase):
actual = linear_solve(jnp.dot(a, a.T), b)
expected = jnp.linalg.solve(jnp.dot(a, a.T), b)
self.assertAllClose(expected, actual, check_dtypes=True)
self.assertAllClose(expected, actual)
actual = api.jit(linear_solve)(jnp.dot(a, a.T), b)
expected = jnp.linalg.solve(jnp.dot(a, a.T), b)
self.assertAllClose(expected, actual, check_dtypes=True)
self.assertAllClose(expected, actual)
jtu.check_grads(lambda x, y: linear_solve(jnp.dot(x, x.T), y),
(a, b), order=2, rtol={jnp.float32: 1e-2})
@ -1670,12 +1670,12 @@ class LaxControlFlowTest(jtu.JaxTestCase):
expected = jnp.linalg.solve(a, b)
actual = api.jit(linear_solve)(a, b)
self.assertAllClose(expected, actual, check_dtypes=True)
self.assertAllClose(expected, actual)
c = rng.randn(3, 2)
expected = jnp.linalg.solve(a, c)
actual = api.vmap(linear_solve, (None, 1), 1)(a, c)
self.assertAllClose(expected, actual, check_dtypes=True)
self.assertAllClose(expected, actual)
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def test_custom_linear_solve_zeros(self):
@ -1723,7 +1723,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
b = rng.randn(2)
expected = jnp.linalg.solve(jnp.exp(a), jnp.cos(b))
actual = build_and_solve(a, b)
self.assertAllClose(expected, actual, atol=1e-5, check_dtypes=True)
self.assertAllClose(expected, actual, atol=1e-5)
jtu.check_grads(build_and_solve, (a, b), atol=1e-5, order=2,
rtol={jnp.float32: 6e-2, jnp.float64: 2e-3})
@ -1749,10 +1749,10 @@ class LaxControlFlowTest(jtu.JaxTestCase):
expected = jnp.linalg.solve(np.asarray(posify(a)), b)
actual = positive_definite_solve(posify(a), b)
self.assertAllClose(expected, actual, check_dtypes=True)
self.assertAllClose(expected, actual)
actual = api.jit(positive_definite_solve)(posify(a), b)
self.assertAllClose(expected, actual, check_dtypes=True)
self.assertAllClose(expected, actual)
# numerical gradients are only well defined if ``a`` is guaranteed to be
# positive definite.
@ -1798,7 +1798,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
expected = jnp.linalg.solve(a, b)
actual = linear_solve(a, b)
self.assertAllClose(expected, actual, check_dtypes=True)
self.assertAllClose(expected, actual)
jtu.check_grads(linear_solve, (a, b), order=2, rtol=2e-3)

View File

@ -33,7 +33,7 @@ class EinsumTest(jtu.JaxTestCase):
def _check(self, s, *ops):
a = np.einsum(s, *ops)
b = jnp.einsum(s, *ops, precision=lax.Precision.HIGHEST)
self.assertAllClose(a, b, atol=1e-4, rtol=1e-4, check_dtypes=True)
self.assertAllClose(a, b, atol=1e-4, rtol=1e-4)
def test_three_operands_1(self):
r = self.rng()

View File

@ -402,7 +402,7 @@ class IndexingTest(jtu.JaxTestCase):
rng = rng_factory(self.rng())
args_maker = lambda: [rng(shape, dtype)]
fun = lambda x: x[indexer]
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
self._CompileAndCheck(fun, args_maker)
@parameterized.named_parameters({
"testcase_name":
@ -503,7 +503,7 @@ class IndexingTest(jtu.JaxTestCase):
return x[indexer]
args_maker = lambda: [rng(shape, dtype), unpacked_indexer]
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
self._CompileAndCheck(fun, args_maker)
@parameterized.named_parameters(
{"testcase_name": "{}_inshape={}_indexer={}"
@ -553,7 +553,7 @@ class IndexingTest(jtu.JaxTestCase):
rng = rng_factory(self.rng())
args_maker = lambda: [rng(shape, dtype), indexer]
fun = lambda x, idx: jnp.asarray(x)[idx]
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
self._CompileAndCheck(fun, args_maker)
@parameterized.named_parameters(
{"testcase_name": "{}_inshape={}_indexer={}"
@ -635,7 +635,7 @@ class IndexingTest(jtu.JaxTestCase):
idx = type(indexer)(util.subvals(indexer_with_dummies, substitutes))
return jnp.asarray(x)[idx]
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
self._CompileAndCheck(fun, args_maker)
def testAdvancedIndexingManually(self):
x = onp.random.RandomState(0).randn(3, 4, 5)
@ -647,7 +647,7 @@ class IndexingTest(jtu.JaxTestCase):
a1 = op(x, index_array)
a2 = cop(x, index_array)
self.assertAllClose(a1, a2, check_dtypes=True)
self.assertAllClose(a1, a2)
op = lambda x, index_array: x[..., index_array, :, index_array, None]
cop = api.jit(op)
@ -655,7 +655,7 @@ class IndexingTest(jtu.JaxTestCase):
a1 = op(x, index_array)
a2 = cop(x, index_array)
self.assertAllClose(a1, a2, check_dtypes=True)
self.assertAllClose(a1, a2)
op = lambda x, index_array: x[index_array, ..., index_array[:, None], None]
cop = api.jit(op)
@ -663,7 +663,7 @@ class IndexingTest(jtu.JaxTestCase):
a1 = op(x, index_array)
a2 = cop(x, index_array)
self.assertAllClose(a1, a2, check_dtypes=True)
self.assertAllClose(a1, a2)
def testUnpacking(self):
@ -676,7 +676,7 @@ class IndexingTest(jtu.JaxTestCase):
a1 = foo(onp.arange(3))
a2 = cfoo(onp.arange(3))
self.assertAllClose(a1, a2, check_dtypes=True)
self.assertAllClose(a1, a2)
def testBooleanIndexingArray1D(self):
idx = onp.array([True, True, False])
@ -739,8 +739,8 @@ class IndexingTest(jtu.JaxTestCase):
primals, tangents = api.jvp(api.grad(f), (x, i), (x, onp.zeros_like(i)))
expected = onp.broadcast_to(
onp.array([0, 3, 0], dtype=onp.float32)[:, None], (3, 4))
self.assertAllClose(expected, primals, check_dtypes=True)
self.assertAllClose(onp.zeros_like(x), tangents, check_dtypes=True)
self.assertAllClose(expected, primals)
self.assertAllClose(onp.zeros_like(x), tangents)
def testTrivialGatherIsntGenerated(self):
# https://github.com/google/jax/issues/1621
@ -791,7 +791,7 @@ class IndexingTest(jtu.JaxTestCase):
def testIndexOutOfBounds(self): # https://github.com/google/jax/issues/2245
array = jnp.ones(5)
self.assertAllClose(array, array[:10], check_dtypes=True)
self.assertAllClose(array, array[:10])
def _broadcastable_shapes(shape):
@ -877,8 +877,8 @@ class IndexedUpdateTest(jtu.JaxTestCase):
jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y)
else:
jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True)
self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker)
self._CompileAndCheck(jax_fn, args_maker)
@parameterized.named_parameters(jtu.cases_from_list({
"testcase_name": "{}_inshape={}_indexer={}_update={}_sugared={}_op={}".format(
@ -904,8 +904,8 @@ class IndexedUpdateTest(jtu.JaxTestCase):
jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y)
else:
jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True)
self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker)
self._CompileAndCheck(jax_fn, args_maker)
@parameterized.named_parameters(jtu.cases_from_list({
"testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
@ -931,8 +931,8 @@ class IndexedUpdateTest(jtu.JaxTestCase):
jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y)
else:
jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True)
self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker)
self._CompileAndCheck(jax_fn, args_maker)
@parameterized.named_parameters(jtu.cases_from_list({
"testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(

View File

@ -91,7 +91,8 @@ OpRecord = collections.namedtuple(
"test_name", "check_dtypes", "tolerance", "inexact"])
def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
test_name=None, check_dtypes=True, tolerance=None, inexact=False):
test_name=None, check_dtypes=True,
tolerance=None, inexact=False):
test_name = test_name or name
return OpRecord(name, nargs, dtypes, shapes, rng_factory, diff_modes,
test_name, check_dtypes, tolerance, inexact)
@ -497,7 +498,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
() in shapes)
empty_shape = any(isinstance(s, tuple) and 0 in s for s in shapes)
self._CompileAndCheck(
fun, args_maker, check_dtypes=True, #not scalar_arg and not empty_shape,
fun, args_maker, #not scalar_arg and not empty_shape,
atol=tol, rtol=tol)
@parameterized.named_parameters(itertools.chain.from_iterable(
@ -525,7 +526,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
() in shapes)
empty_shape = any(isinstance(s, tuple) and 0 in s for s in shapes)
self._CompileAndCheck(
fun, args_maker, check_dtypes=True, # not scalar_arg and not empty_shape,
fun, args_maker, # not scalar_arg and not empty_shape,
atol=tol, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
@ -577,7 +578,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = self._GetArgsMaker(rng, shapes, dtypes)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker,
check_dtypes=jtu.PYTHON_SCALAR_SHAPE not in shapes)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
@ -616,7 +617,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
check_dtypes=jnp.bfloat16 not in (dtype, out_dtype),
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, atol=tol,
self._CompileAndCheck(jnp_fun, args_maker, atol=tol,
rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
@ -646,8 +647,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims)
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_axis={}".format(
@ -661,7 +662,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
jnp_fun = lambda x: jnp.count_nonzero(x, axis)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(
@ -733,13 +734,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
try:
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
except ValueError as e:
if str(e) == "All-NaN slice encountered":
self.skipTest("JAX doesn't support checking for all-NaN slices")
else:
raise
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": rec.test_name.capitalize(), "name": rec.name,
@ -755,8 +756,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np_fun = partial(np_op, axis=0)
jnp_fun = partial(jnp_op, axis=0)
args_maker = lambda: [np.zeros((2, 0))]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_{}_{}".format(
@ -793,9 +794,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
tol_spec = {dtypes.bfloat16: 3e-1, np.float16: 0.15}
tol = max(jtu.tolerance(lhs_dtype, tol_spec),
jtu.tolerance(rhs_dtype, tol_spec))
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True,
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, atol=tol,
self._CompileAndCheck(jnp_fun, args_maker, atol=tol,
rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
@ -830,9 +831,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
x = x.astype(np.float32) if lhs_dtype == jnp.bfloat16 else x
y = y.astype(np.float32) if rhs_dtype == jnp.bfloat16 else y
return np.dot(x, y).astype(jnp.promote_types(lhs_dtype, rhs_dtype))
self._CheckAgainstNumpy(np_dot, jnp.dot, args_maker, check_dtypes=True,
self._CheckAgainstNumpy(np_dot, jnp.dot, args_maker,
tol=tol)
self._CompileAndCheck(jnp.dot, args_maker, check_dtypes=True, atol=tol,
self._CompileAndCheck(jnp.dot, args_maker, atol=tol,
rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
@ -866,10 +867,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np.complex128: 1e-12}
if jtu.device_under_test() == "tpu":
tol[np.float32] = tol[np.complex64] = 4e-2
self._CheckAgainstNumpy(np_fun, jnp.matmul, args_maker,
check_dtypes=True, tol=tol)
self._CompileAndCheck(jnp.matmul, args_maker, check_dtypes=True, atol=tol,
rtol=tol)
self._CheckAgainstNumpy(np_fun, jnp.matmul, args_maker, tol=tol)
self._CompileAndCheck(jnp.matmul, args_maker, atol=tol, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_{}_{}".format(
@ -902,9 +901,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np.complex64: 1e-3, np.complex128: 1e-12}
if jtu.device_under_test() == "tpu":
tol[np.float32] = tol[np.complex64] = 2e-1
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True,
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker)
def testTensordotErrors(self):
a = np.random.random((3, 2, 2))
@ -941,8 +940,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [rng(element_shape, dtype), rng(test_shape, dtype)]
jnp_fun = lambda e, t: jnp.isin(e, t, invert=invert)
np_fun = lambda e, t: np.isin(e, t, invert=invert)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
@ -960,8 +959,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [rng(element_shape, dtype), rng(test_shape, dtype)]
jnp_fun = lambda e, t: jnp.in1d(e, t, invert=invert)
np_fun = lambda e, t: np.in1d(e, t, invert=invert)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
@ -1015,7 +1014,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
# TODO(phawkins): the promotion behavior changed in Numpy 1.17.
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker)
def testClipError(self):
with self.assertRaisesRegex(ValueError, "At most one of a_min and a_max.*"):
@ -1044,15 +1043,15 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def testOperatorRound(self):
self.assertAllClose(round(np.float32(7.532), 1),
round(jnp.float32(7.5), 1), check_dtypes=True)
round(jnp.float32(7.5), 1))
self.assertAllClose(round(np.float32(1.234), 2),
round(jnp.float32(1.234), 2), check_dtypes=True)
round(jnp.float32(1.234), 2))
self.assertAllClose(round(np.float32(1.234)),
round(jnp.float32(1.234)), check_dtypes=False)
self.assertAllClose(round(np.float32(7.532), 1),
round(jnp.array(7.5, jnp.float32), 1), check_dtypes=True)
round(jnp.array(7.5, jnp.float32), 1))
self.assertAllClose(round(np.float32(1.234), 2),
round(jnp.array(1.234, jnp.float32), 2), check_dtypes=True)
round(jnp.array(1.234, jnp.float32), 2))
self.assertAllClose(round(np.float32(1.234)),
round(jnp.array(1.234, jnp.float32)),
check_dtypes=False)
@ -1098,7 +1097,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape=[{}]_reps={}".format(
@ -1117,7 +1116,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(
@ -1128,7 +1127,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def testExtract(self, shape, dtype):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, jnp.float32), rng(shape, dtype)]
self._CheckAgainstNumpy(np.extract, jnp.extract, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np.extract, jnp.extract, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_axis={}".format(
@ -1150,7 +1149,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np_fun = partial(np.compress, axis=axis)
jnp_fun = partial(jnp.compress, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_condition=array[{}]_axis={}".format(
@ -1166,7 +1165,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [np.array(condition), rng(shape, dtype)]
np_fun = partial(np.compress, axis=axis)
jnp_fun = partial(jnp.compress, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format(
@ -1193,8 +1192,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def args_maker():
return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format(
@ -1221,8 +1220,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def args_maker():
return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape=[{}]_axis={}_repeats={}".format(
@ -1240,8 +1239,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_ind={}_inv={}_count={}".format(
@ -1260,7 +1259,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
np_fun = lambda x: np.unique(x, return_index, return_inverse, return_counts)
jnp_fun = lambda x: jnp.unique(x, return_index, return_inverse, return_counts)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
def testIssue1233(self):
'''
@ -1273,10 +1272,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
lax_ans = jnp.repeat(m, repeats, axis)
numpy_ans = np.repeat(m, repeats, axis)
self.assertAllClose(lax_ans, numpy_ans, check_dtypes=True, rtol=tol, atol=tol)
self.assertAllClose(lax_ans, numpy_ans, rtol=tol, atol=tol)
jnp_fun = lambda arg: jnp.repeat(arg, repeats = repeats, axis=axis)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker)
m = jnp.array([1,2,3,4,5,6])
args_maker = lambda: [m]
@ -1313,8 +1312,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
attempt_sideeffect(np_input)
attempt_sideeffect(jnp_input)
self.assertAllClose(np_input, expected_np_input_after_call, check_dtypes=True)
self.assertAllClose(jnp_input, expected_jnp_input_after_call, check_dtypes=True)
self.assertAllClose(np_input, expected_np_input_after_call)
self.assertAllClose(jnp_input, expected_jnp_input_after_call)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "op={}_xshape=[{}]_yshape=[{}]_mode={}".format(
@ -1339,7 +1338,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
jnp_fun = partial(jnp_op, mode=mode, precision=precision)
tol = {np.float16: 2e-1, np.float32: 1e-2, np.float64: 1e-14}
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "op={}_shape=[{}]_axis={}_out_dtype={}".format(
@ -1365,9 +1364,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
tol_thresholds = {dtypes.bfloat16: 4e-2}
tol = max(jtu.tolerance(dtype, tol_thresholds),
jtu.tolerance(out_dtype, tol_thresholds))
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True,
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
@ -1411,8 +1410,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np_fun = lambda: np.tri(n, M=m, k=k, dtype=dtype)
jnp_fun = lambda: jnp.tri(n, M=m, k=k, dtype=dtype)
args_maker = lambda: []
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_shape={}_k={}".format(
@ -1428,8 +1427,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np_fun = lambda arg: getattr(np, op)(arg, k=k)
jnp_fun = lambda arg: getattr(jnp, op)(arg, k=k)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_ndim={}_n={}".format(ndim, n),
@ -1453,8 +1452,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np_fun = lambda arg: np.diag(arg, k)
jnp_fun = lambda arg: jnp.diag(arg, k)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_offset={}_axis1={}_axis2={}".format(
@ -1472,8 +1471,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np_fun = lambda arg: np.diagonal(arg, offset, axis1, axis2)
jnp_fun = lambda arg: jnp.diagonal(arg, offset, axis1, axis2)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_n={}".format(np.dtype(dtype).name, n),
@ -1484,8 +1483,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np_fun = lambda: np.identity(n, dtype)
jnp_fun = lambda: jnp.identity(n, dtype)
args_maker = lambda: []
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_x1={}_x2={}_x1_rng={}".format(
@ -1515,8 +1514,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
jnp_fun = lambda x1, x2: jnp.ldexp(x1, x2)
args_maker = lambda: [x1_rng(x1_shape, x1_dtype),
x2_rng(x2_shape, np.int32)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_x={}_rng_factory={}".format(
@ -1539,8 +1538,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np_fun = lambda x: np.frexp(x)
jnp_fun = lambda x: jnp.frexp(x)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
@ -1565,8 +1564,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
return np.trace(arg, offset, axis1, axis2, out_dtype)
jnp_fun = lambda arg: jnp.trace(arg, offset, axis1, axis2, out_dtype)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_a={}_v={}_side={}".format(
@ -1585,8 +1584,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [jnp.sort(rng(ashape, dtype)), rng(vshape, dtype)]
np_fun = lambda a, v: np.searchsorted(a, v, side=side)
jnp_fun = lambda a, v: jnp.searchsorted(a, v, side=side)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_x={}_bins={}_right={}_reverse={}".format(
@ -1607,8 +1606,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [rng(xshape, dtype), jnp.sort(rng(binshape, dtype))[order]]
np_fun = lambda x, bins: np.digitize(x, bins, right=right)
jnp_fun = lambda x, bins: jnp.digitize(x, bins, right=right)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}".format(
@ -1629,7 +1628,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
np_fun = _promote_like_jnp(partial(np.stack, axis=axis))
jnp_fun = partial(jnp.stack, axis=axis)
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
@ -1651,7 +1650,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
np_fun = _promote_like_jnp(getattr(np, op))
jnp_fun = getattr(jnp, op)
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_outdtype={}".format(
@ -1667,8 +1666,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np_fun = lambda fill_value: np.full(shape, fill_value, dtype=out_dtype)
jnp_fun = lambda fill_value: jnp.full(shape, fill_value, dtype=out_dtype)
args_maker = lambda: [rng((), fill_value_dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(
jtu.cases_from_list(
@ -1684,8 +1683,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def args_maker(): return []
np_op = partial(np_op, shape, dtype)
jnp_op = partial(jnp_op, shape, dtype)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
def testOnesWithInvalidShape(self):
with self.assertRaises(TypeError):
@ -1708,8 +1707,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np_fun = lambda x, fill_value: np.full_like(x, fill_value, dtype=out_dtype)
jnp_fun = lambda x, fill_value: jnp.full_like(x, fill_value, dtype=out_dtype)
args_maker = lambda: [rng(shape, in_dtype), rng((), fill_value_dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}_{}sections".format(
@ -1725,8 +1724,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np_fun = lambda x: np.split(x, num_sections, axis=axis)
jnp_fun = lambda x: jnp.split(x, num_sections, axis=axis)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
def testSplitTypeError(self):
# If we pass an ndarray for indices_or_sections -> no error
@ -1774,7 +1773,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
# linspace() compares poorly to numpy when using bfloat16
if dtype != jnp.bfloat16:
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True,
self._CompileAndCheck(jnp_fun, args_maker,
atol=tol, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
@ -1808,7 +1807,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
if dtype != jnp.bfloat16:
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}_{}sections".format(
@ -1833,8 +1832,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np_fun = lambda x: fn(np, axis)(x, num_sections)
jnp_fun = lambda x: fn(jnp, axis)(x, num_sections)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_outshape={}_order={}".format(
@ -1860,8 +1859,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np_fun = lambda x: np.reshape(x, out_shape, order=order)
jnp_fun = lambda x: jnp.reshape(x, out_shape, order=order)
args_maker = lambda: [rng(arg_shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_outshape={}".format(
@ -1880,8 +1879,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np_fun = lambda x: np.reshape(x, out_shape)
jnp_fun = lambda x: x.reshape(*out_shape)
args_maker = lambda: [rng(arg_shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_expanddim={!r}".format(
@ -1898,11 +1897,11 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np_fun = lambda x: np.expand_dims(x, dim)
jnp_fun = lambda x: jnp.expand_dims(x, dim)
args_maker = lambda: [rng(arg_shape, dtype)]
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker)
if isinstance(dim, tuple) and numpy_version < (1, 18, 0):
raise SkipTest("support for multiple axes added in NumPy 1.18.0")
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_axes=({},{})".format(
@ -1918,8 +1917,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np_fun = lambda x: np.swapaxes(x, ax1, ax2)
jnp_fun = lambda x: jnp.swapaxes(x, ax1, ax2)
args_maker = lambda: [rng(arg_shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_axis={!r}".format(
@ -1940,8 +1939,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np_fun = lambda x: np.squeeze(x, ax)
jnp_fun = lambda x: jnp.squeeze(x, ax)
args_maker = lambda: [rng(arg_shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_axis={}_weights={}_returned={}".format(
@ -2004,8 +2003,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
else:
np_fun = partial(np.array, dtype=dtype)
jnp_fun = jnp.array
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
def testArrayUnsupportedDtypeError(self):
with self.assertRaisesRegex(TypeError,
@ -2042,8 +2041,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
ans = jnp.array(bytearray(b'\x2a'))
self.assertAllClose(
ans,
np.array([0x2a], dtype=np.uint8),
check_dtypes=True)
np.array([0x2a], dtype=np.uint8))
def testIsClose(self):
c_isclose = api.jit(jnp.isclose)
@ -2195,8 +2193,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
jnp_op = lambda x: jnp.flip(x, axis)
np_op = lambda x: np.flip(x, axis)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(
@ -2210,8 +2208,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
jnp_op = lambda x: jnp.flipud(x)
np_op = lambda x: np.flipud(x)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
@ -2226,8 +2224,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
jnp_op = lambda x: jnp.fliplr(x)
np_op = lambda x: np.fliplr(x)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
@ -2248,15 +2246,15 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
jnp_op = lambda x: jnp.rot90(x, k, axes)
np_op = lambda x: np.rot90(x, k, axes)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
# TODO(mattjj): test infix operator overrides
def testRavel(self):
rng = np.random.RandomState(0)
args_maker = lambda: [rng.randn(3, 4).astype("float32")]
self._CompileAndCheck(lambda x: x.ravel(), args_maker, check_dtypes=True)
self._CompileAndCheck(lambda x: x.ravel(), args_maker)
@parameterized.parameters(
(0, (2, 1, 3)),
@ -2265,12 +2263,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
([0, 1, 2], (2, 2)),
([[[0, 1], [2, 3]]], (2, 2)))
def testUnravelIndex(self, flat_index, shape):
self._CheckAgainstNumpy(
np.unravel_index,
jnp.unravel_index,
lambda: (flat_index, shape),
check_dtypes=True
)
self._CheckAgainstNumpy(np.unravel_index, jnp.unravel_index,
lambda: (flat_index, shape))
def testUnravelIndexOOB(self):
self.assertEqual(jnp.unravel_index(2, (2,)), (1,))
@ -2282,8 +2276,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [rng.randn(3, 4).astype("float32")]
np_op = lambda x: np.asarray(x).astype(jnp.int32)
jnp_op = lambda x: jnp.asarray(x).astype(jnp.int32)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_dtype={}".format(
@ -2305,8 +2299,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
jnp_op = lambda x: jnp.asarray(x).view(dtype)
# Above may produce signaling nans; ignore warnings from invalid values.
with np.errstate(invalid='ignore'):
self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
def testPathologicalFloats(self):
args_maker = lambda: [np.array([
@ -2325,8 +2319,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np_op = lambda x: np.asarray(x).view('float32').view('uint32')
jnp_op = lambda x: jnp.asarray(x).view('float32').view('uint32')
self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
# TODO(mattjj): test other ndarray-like method overrides
@ -2340,29 +2334,29 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
# from https://github.com/google/jax/issues/145
expected = np.arange(0.0, 1.0, 0.1, dtype=jnp.float_)
ans = jnp.arange(0.0, 1.0, 0.1)
self.assertAllClose(expected, ans, check_dtypes=True)
self.assertAllClose(expected, ans)
def testSortManually(self):
# manual tests for sort are nice because we don't have to worry about ties.
# lax.sort is tested combinatorially.
ans = jnp.sort(np.array([16, 15, 23, 42, 8, 4]))
expected = np.array([4, 8, 15, 16, 23, 42])
self.assertAllClose(expected, ans, check_dtypes=True)
self.assertAllClose(expected, ans)
a = np.array([[1, 4], [3, 1]])
ans = jnp.sort(a, axis=None)
expected = np.array([1, 1, 3, 4])
self.assertAllClose(expected, ans, check_dtypes=True)
self.assertAllClose(expected, ans)
a = np.array([[1, 4], [3, 1]])
ans = jnp.sort(a) # last axis
expected = np.array([[1, 4], [1, 3]])
self.assertAllClose(expected, ans, check_dtypes=True)
self.assertAllClose(expected, ans)
a = np.array([[1, 4], [3, 1]])
ans = jnp.sort(a, axis=0)
expected = np.array([[1, 1], [3, 4]])
self.assertAllClose(expected, ans, check_dtypes=True)
self.assertAllClose(expected, ans)
def testArgsortManually(self):
x = np.array([16, 15, 23, 42, 8, 4])
@ -2394,8 +2388,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [np.random.randint(50, size=(5 ,5))]
jnp_op = lambda x: jnp.msort(x)
np_op = lambda x: np.msort(x)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_shifts={}_axis={}".format(
@ -2420,8 +2414,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype), np.array(shifts)]
jnp_op = partial(jnp.roll, axis=axis)
np_op = partial(np.roll, axis=axis)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}_start={}".format(
@ -2439,8 +2433,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
jnp_op = partial(jnp.rollaxis, axis=axis, start=start)
np_op = partial(np.rollaxis, axis=axis, start=start)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}_bitorder={}".format(
@ -2459,8 +2453,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
jnp_op = partial(jnp.packbits, axis=axis, bitorder=bitorder)
np_op = partial(np.packbits, axis=axis, bitorder=bitorder)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}_bitorder={}_count={}".format(
@ -2479,8 +2473,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
jnp_op = partial(jnp.unpackbits, axis=axis, bitorder=bitorder)
np_op = partial(np.unpackbits, axis=axis, bitorder=bitorder)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_index={}_axis={}_mode={}".format(
@ -2506,8 +2500,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
rng_indices = jtu.rand_int(self.rng(), -5, 5)
jnp_op = lambda x, i: jnp.take(x, i, axis=axis, mode=mode)
np_op = lambda x, i: np.take(x, i, axis=axis, mode=mode)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_ishape={}_axis={}".format(
@ -2541,8 +2535,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
if hasattr(np, "take_along_axis"):
np_op = lambda x, i: np.take_along_axis(x, i, axis=axis)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_n={}_increasing={}".format(
@ -2606,9 +2600,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
rng = rng_factory(self.rng())
args_maker = lambda: [rng(shape, dtype)
for shape, dtype in zip(shapes, dtypes)]
self._CheckAgainstNumpy(np.ix_, jnp.ix_, args_maker,
check_dtypes=True)
self._CompileAndCheck(jnp.ix_, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np.ix_, jnp.ix_, args_maker)
self._CompileAndCheck(jnp.ix_, args_maker)
@parameterized.named_parameters(
jtu.cases_from_list(
@ -2630,8 +2623,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
dtype=dtype, sparse=sparse)
jnp_fun = partial(jnp.indices, dimensions=dimensions,
dtype=dtype, sparse=sparse)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -2681,7 +2674,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
jtu.tolerance(q_dtype, tol_spec))
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, rtol=tol)
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
@ -2714,7 +2707,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
tol = jtu.tolerance(a_dtype, tol_spec)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, rtol=tol)
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
@ -2745,9 +2738,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = self._GetArgsMaker(rng_factory(self.rng()), shapes, dtypes)
def np_fun(cond, x, y):
return _promote_like_jnp(partial(np.where, cond))(x, y)
self._CheckAgainstNumpy(np_fun, jnp.where, args_maker,
check_dtypes=True)
self._CompileAndCheck(jnp.where, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp.where, args_maker)
self._CompileAndCheck(jnp.where, args_maker)
def testWhereScalarPromotion(self):
x = jnp.where(jnp.array([True, False]), 3,
@ -2782,7 +2774,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np.asarray(default, dtype=dtype))
self._CheckAgainstNumpy(np_fun, jnp.select, args_maker,
check_dtypes=False)
self._CompileAndCheck(jnp.select, args_maker, check_dtypes=True,
self._CompileAndCheck(jnp.select, args_maker,
rtol={np.float64: 1e-7, np.complex128: 1e-7})
@ -2823,7 +2815,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
a = np.arange(6) + 1
ans = jnp.reshape(a, (3, 2), order='F')
expected = np.reshape(a, (3, 2), order='F')
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_dtype={}".format(op, pytype.__name__),
@ -2836,8 +2828,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np_fun = lambda arg: getattr(np, op)(arg).astype(dtype)
jnp_fun = lambda arg: getattr(jnp, op)(arg)
args_maker = lambda: [pytype(2)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{
@ -2865,7 +2857,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
jnp_fun = partial(jnp.bincount, minlength=minlength, length=length)
if length is not None:
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker)
if length is None:
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
@ -2906,32 +2898,29 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
])))
def testBlock(self, input):
args_maker = lambda: [input]
self._CheckAgainstNumpy(np.block, jnp.block, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp.block, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np.block, jnp.block, args_maker)
self._CompileAndCheck(jnp.block, args_maker)
def testLongLong(self):
self.assertAllClose(np.int64(7), api.jit(lambda x: x)(np.longlong(7)),
check_dtypes=True)
self.assertAllClose(np.int64(7), api.jit(lambda x: x)(np.longlong(7)))
def testArange(self):
# test cases inspired by dask tests at
# https://github.com/dask/dask/blob/master/dask/array/tests/test_creation.py#L92
self.assertAllClose(jnp.arange(77),
np.arange(77, dtype=jnp.int_), check_dtypes=True)
np.arange(77, dtype=jnp.int_))
self.assertAllClose(jnp.arange(2, 13),
np.arange(2, 13, dtype=jnp.int_), check_dtypes=True)
np.arange(2, 13, dtype=jnp.int_))
self.assertAllClose(jnp.arange(4, 21, 9),
np.arange(4, 21, 9, dtype=jnp.int_), check_dtypes=True)
np.arange(4, 21, 9, dtype=jnp.int_))
self.assertAllClose(jnp.arange(53, 5, -3),
np.arange(53, 5, -3, dtype=jnp.int_),
check_dtypes=True)
np.arange(53, 5, -3, dtype=jnp.int_))
self.assertAllClose(jnp.arange(77, dtype=float),
np.arange(77, dtype=float), check_dtypes=True)
np.arange(77, dtype=float))
self.assertAllClose(jnp.arange(2, 13, dtype=int),
np.arange(2, 13, dtype=int), check_dtypes=True)
np.arange(2, 13, dtype=int))
self.assertAllClose(jnp.arange(0, 1, -0.5),
np.arange(0, 1, -0.5, dtype=jnp.float_),
check_dtypes=True)
np.arange(0, 1, -0.5, dtype=jnp.float_))
self.assertRaises(TypeError, lambda: jnp.arange())
@ -2976,8 +2965,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
# argument.
return lax.tie_in(y, 7.)
self.assertAllClose(np.zeros(3,), api.grad(f)(np.ones(3,)),
check_dtypes=True)
self.assertAllClose(np.zeros(3,), api.grad(f)(np.ones(3,)))
# NOTE(mattjj): I disabled this test when removing lax._safe_mul because this
# is a numerical stability issue that should be solved with a custom jvp rule
@ -2985,8 +2973,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
# def testIssue777(self):
# x = jnp.linspace(-200, 0, 4, dtype=np.float32)
# f = api.grad(lambda x: jnp.sum(1 / (1 + jnp.exp(-x))))
# self.assertAllClose(f(x), np.array([0., 0., 0., 0.25], dtype=np.float32),
# check_dtypes=True)
# self.assertAllClose(f(x), np.array([0., 0., 0., 0.25], dtype=np.float32))
@parameterized.named_parameters(
jtu.cases_from_list(
@ -3017,7 +3004,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
expected = np_op(x)
actual = jnp_op(x)
tol = jtu.tolerance(dtype, {np.float32: 1e-3, np.float64: 1e-7})
self.assertAllClose(expected, actual, check_dtypes=True, atol=tol,
self.assertAllClose(expected, actual, atol=tol,
rtol=tol)
def testIssue883(self):
@ -3067,9 +3054,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
not jnp.issubdtype(out_dtype, jnp.complexfloating)):
self.assertRaises(ValueError, lambda: jnp_fun(*args_maker()))
else:
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True,
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, rtol=tol,
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol,
atol=tol)
@parameterized.named_parameters(
@ -3100,9 +3087,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
not jnp.issubdtype(out_dtype, jnp.complexfloating)):
self.assertRaises(ValueError, lambda: jnp_fun(*args_maker()))
else:
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True,
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, rtol=tol,
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol,
atol=tol)
@parameterized.named_parameters(
@ -3128,7 +3115,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
tol = jtu.join_tolerance(tol, jtu.tolerance(dtype))
self._CheckAgainstNumpy(
np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, atol=tol,
self._CompileAndCheck(jnp_fun, args_maker, atol=tol,
rtol=tol)
def testIssue967(self):
@ -3157,7 +3144,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(
np_fun, jnp_fun, args_maker, check_dtypes=False,
tol=1e-2 if jtu.device_under_test() == "tpu" else None)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_{}_{}".format(jtu.format_shape_dtype_string(shape, dtype),
@ -3184,8 +3171,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
(None if begin_dtype is None else rng(begin_shape, begin_dtype))]
np_fun = lambda x, to_end, to_begin: np.ediff1d(x, to_end, to_begin)
jnp_fun = lambda x, to_end, to_begin: jnp.ediff1d(x, to_end, to_begin)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
def testEDiff1dWithDtypeCast(self):
rng = jtu.rand_default(self.rng())
@ -3196,8 +3183,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype), rng(shape, end_dtype), rng(shape, dtype)]
np_fun = lambda x, to_end, to_begin: np.ediff1d(x, to_end, to_begin)
jnp_fun = lambda x, to_end, to_begin: jnp.ediff1d(x, to_end, to_begin)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(
jtu.cases_from_list(
@ -3216,7 +3203,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
[dtype] * len(shapes))
np_fun = partial(np.meshgrid, indexing=indexing, sparse=sparse)
jnp_fun = partial(jnp.meshgrid, indexing=indexing, sparse=sparse)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(
jtu.cases_from_list(
@ -3272,7 +3259,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
rng = rng_factory(self.rng())
endpoints = rng((2,), dtype)
out = jnp.linspace(*endpoints, 10, dtype=dtype)
self.assertAllClose(out[[0, -1]], endpoints, check_dtypes=True, rtol=0, atol=0)
self.assertAllClose(out[[0, -1]], endpoints, rtol=0, atol=0)
@parameterized.named_parameters(
jtu.cases_from_list(
@ -3454,8 +3441,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = self._GetArgsMaker(rng, [from_shape], [np.float32])
np_op = lambda x: np.broadcast_to(x, to_shape)
jnp_op = lambda x: jnp.broadcast_to(x, to_shape)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
def testBroadcastToIssue1522(self):
self.assertRaisesRegex(
@ -3540,7 +3527,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np_fun = lambda y: np.gradient(y, *varargs, axis=axis)
self._CheckAgainstNumpy(
np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker)
def testZerosShapeErrors(self):
# see https://github.com/google/jax/issues/1822
@ -3556,9 +3543,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def testTraceMethod(self):
x = self.rng().randn(3, 4).astype(jnp.float_)
self.assertAllClose(x.trace(), jnp.array(x).trace(), check_dtypes=True)
self.assertAllClose(x.trace(), api.jit(lambda y: y.trace())(x),
check_dtypes=True)
self.assertAllClose(x.trace(), jnp.array(x).trace())
self.assertAllClose(x.trace(), api.jit(lambda y: y.trace())(x))
def testIntegerPowersArePrecise(self):
# See https://github.com/google/jax/pull/3036

View File

@ -141,8 +141,8 @@ class VectorizeTest(jtu.JaxTestCase):
return y
x = jnp.arange(3)
self.assertAllClose(x, f('foo', x), check_dtypes=True)
self.assertAllClose(x, jax.jit(f, 0)('foo', x), check_dtypes=True)
self.assertAllClose(x, f('foo', x))
self.assertAllClose(x, jax.jit(f, 0)('foo', x))
def test_exclude_second(self):
@ -153,8 +153,8 @@ class VectorizeTest(jtu.JaxTestCase):
return x
x = jnp.arange(3)
self.assertAllClose(x, f(x, 'foo'), check_dtypes=True)
self.assertAllClose(x, jax.jit(f, 1)(x, 'foo'), check_dtypes=True)
self.assertAllClose(x, f(x, 'foo'))
self.assertAllClose(x, jax.jit(f, 1)(x, 'foo'))
def test_exclude_errors(self):
with self.assertRaisesRegex(

View File

@ -92,7 +92,6 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
partial(scipy_cg, M=M, maxiter=1),
partial(lax_cg, M=M, maxiter=1),
args_maker,
check_dtypes=True,
tol=1e-3)
# TODO(shoyer,mattjj): I had to loosen the tolerance for complex64[7,7]
@ -101,14 +100,12 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
partial(scipy_cg, M=M, maxiter=3),
partial(lax_cg, M=M, maxiter=3),
args_maker,
check_dtypes=True,
tol=3e-3)
self._CheckAgainstNumpy(
np.linalg.solve,
partial(lax_cg, M=M, atol=1e-6),
args_maker,
check_dtypes=True,
tol=2e-2)
@parameterized.named_parameters(jtu.cases_from_list(
@ -125,10 +122,10 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
expected = np.linalg.solve(posify(a), b)
actual = lax_cg(posify(a), b)
self.assertAllClose(expected, actual, check_dtypes=True)
self.assertAllClose(expected, actual)
actual = jit(lax_cg)(posify(a), b)
self.assertAllClose(expected, actual, check_dtypes=True)
self.assertAllClose(expected, actual)
# numerical gradients are only well defined if ``a`` is guaranteed to be
# positive definite.
@ -141,7 +138,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
b = jnp.arange(9.0).reshape((3, 3))
expected = b / 2
actual, _ = jax.scipy.sparse.linalg.cg(A, b)
self.assertAllClose(expected, actual, check_dtypes=True)
self.assertAllClose(expected, actual)
def test_cg_pytree(self):
A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}

View File

@ -108,8 +108,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
self._CompileAndCheck(lax_fun, args_maker)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
@ -129,7 +129,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
args = args_maker()
self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
check_dtypes=False)
self._CompileAndCheck(lax_op, args_maker, check_dtypes=True, rtol=1e-5)
self._CompileAndCheck(lax_op, args_maker, rtol=1e-5)
if test_autodiff:
jtu.check_grads(lax_op, args, order=1,
@ -153,14 +153,14 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
rng = rng_factory(self.rng())
args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
tol={onp.float32: 1e-3, onp.float64: 1e-14})
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lax_fun, args_maker)
def testIssue980(self):
x = onp.full((4,), -1e20, dtype=onp.float32)
self.assertAllClose(onp.zeros((4,), dtype=onp.float32),
lsp_special.expit(x), check_dtypes=True)
lsp_special.expit(x))
def testXlogyShouldReturnZero(self):
self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False)

View File

@ -187,7 +187,7 @@ class LaxTest(jtu.JaxTestCase):
rng = rng_factory(self.rng())
args_maker = lambda: [rng(shape, dtype) for shape in shapes]
op = getattr(lax, op_name)
self._CompileAndCheck(op, args_maker, check_dtypes=True)
self._CompileAndCheck(op, args_maker)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
@ -222,7 +222,7 @@ class LaxTest(jtu.JaxTestCase):
rng = rng_factory(self.rng())
args_maker = lambda: [rng((2, 3), from_dtype)]
op = lambda x: lax.convert_element_type(x, to_dtype)
self._CompileAndCheck(op, args_maker, check_dtypes=True)
self._CompileAndCheck(op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_from_dtype={}_to_dtype={}"
@ -249,7 +249,7 @@ class LaxTest(jtu.JaxTestCase):
rng = rng_factory(self.rng())
args_maker = lambda: [rng((2, 3), from_dtype)]
op = lambda x: lax.bitcast_convert_type(x, to_dtype)
self._CompileAndCheck(op, args_maker, check_dtypes=True)
self._CompileAndCheck(op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_from_dtype={}_to_dtype={}"
@ -284,7 +284,7 @@ class LaxTest(jtu.JaxTestCase):
rng = rng_factory(self.rng())
shapes = [min_shape, operand_shape, max_shape]
args_maker = lambda: [rng(shape, dtype) for shape in shapes]
self._CompileAndCheck(lax.clamp, args_maker, check_dtypes=True)
self._CompileAndCheck(lax.clamp, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_min_shape={}_operand_shape={}_max_shape={}".format(
@ -325,7 +325,7 @@ class LaxTest(jtu.JaxTestCase):
for size, _ in zip(itertools.cycle([3, 1, 4]), range(num_arrs))]
args_maker = lambda: [rng(shape, dtype) for shape in shapes]
op = lambda *args: lax.concatenate(args, dim)
self._CompileAndCheck(op, args_maker, check_dtypes=True)
self._CompileAndCheck(op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format(
@ -368,7 +368,7 @@ class LaxTest(jtu.JaxTestCase):
def fun(lhs, rhs):
return lax.conv(lhs, rhs, strides, padding)
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
self._CompileAndCheck(fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -419,7 +419,7 @@ class LaxTest(jtu.JaxTestCase):
return lax.conv_with_general_padding(
lhs, rhs, strides, padding, lhs_dilation, rhs_dilation)
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
self._CompileAndCheck(fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
@ -502,7 +502,7 @@ class LaxTest(jtu.JaxTestCase):
dimension_numbers, feature_group_count=feature_group_count,
batch_group_count=batch_group_count)
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
self._CompileAndCheck(fun, args_maker)
# TODO(mattjj): test conv_general_dilated against numpy
@ -512,7 +512,7 @@ class LaxTest(jtu.JaxTestCase):
return [rng((10, 5), onp.float32), rng((5, 7), onp.float32)]
jnp_fun = partial(lax.conv_general_dilated, window_strides=(),
padding='VALID', dimension_numbers=('NC', 'IO', 'NC'))
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker)
self._CheckAgainstNumpy(jnp_fun, onp.dot, args_maker, tol=.1)
@ -685,8 +685,7 @@ class LaxTest(jtu.JaxTestCase):
def testDot(self, lhs_shape, rhs_shape, dtype, precision, rng_factory):
rng = rng_factory(self.rng())
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
self._CompileAndCheck(partial(lax.dot, precision=precision), args_maker,
check_dtypes=True)
self._CompileAndCheck(partial(lax.dot, precision=precision), args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_lhs_shape={}_rhs_shape={}".format(
@ -798,7 +797,7 @@ class LaxTest(jtu.JaxTestCase):
rng = rng_factory(self.rng())
args_maker = lambda: [rng(shape, dtype)]
op = lambda x: lax.broadcast(x, broadcast_sizes)
self._CompileAndCheck(op, args_maker, check_dtypes=True)
self._CompileAndCheck(op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_broadcast_sizes={}".format(
@ -835,7 +834,7 @@ class LaxTest(jtu.JaxTestCase):
rng = rng_factory(self.rng())
args_maker = lambda: [rng(inshape, dtype)]
op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
self._CompileAndCheck(op, args_maker, check_dtypes=True)
self._CompileAndCheck(op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_outshape={}_bcdims={}".format(
@ -921,7 +920,7 @@ class LaxTest(jtu.JaxTestCase):
args_maker = lambda: [rng(arg_shape, onp.float32)]
op = lambda x: lax.squeeze(x, dimensions)
numpy_op = lambda x: lax_reference.squeeze(x, dimensions)
self._CompileAndCheck(op, args_maker, check_dtypes=True)
self._CompileAndCheck(op, args_maker)
self._CheckAgainstNumpy(op, numpy_op, args_maker)
check_grads(op, args_maker(), 2, ["fwd", "rev"], eps=1.)
@ -940,7 +939,7 @@ class LaxTest(jtu.JaxTestCase):
rng = rng_factory(self.rng())
args_maker = lambda: [rng(arg_shape, dtype)]
op = lambda x: lax.reshape(x, out_shape)
self._CompileAndCheck(op, args_maker, check_dtypes=True)
self._CompileAndCheck(op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_outshape={}".format(
@ -971,7 +970,7 @@ class LaxTest(jtu.JaxTestCase):
rng = rng_factory(self.rng())
args_maker = lambda: [rng(shape, dtype)]
fun = lambda operand: lax.pad(operand, onp.array(0, dtype), pads)
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
self._CompileAndCheck(fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_pads={}"
@ -1018,7 +1017,7 @@ class LaxTest(jtu.JaxTestCase):
return [rng(pred_shape, onp.bool_), rng(arg_shape, arg_dtype),
rng(arg_shape, arg_dtype)]
rng = rng_factory(self.rng())
return self._CompileAndCheck(lax.select, args_maker, check_dtypes=True)
return self._CompileAndCheck(lax.select, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_predshape={}_argshapes={}".format(
@ -1061,7 +1060,7 @@ class LaxTest(jtu.JaxTestCase):
rng = rng_factory(self.rng())
args_maker = lambda: [rng(shape, dtype)]
op = lambda x: lax.slice(x, starts, limits, strides)
self._CompileAndCheck(op, args_maker, check_dtypes=True)
self._CompileAndCheck(op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -1109,7 +1108,7 @@ class LaxTest(jtu.JaxTestCase):
rng = rng_factory(self.rng())
args_maker = lambda: [rng(shape, dtype), onp.array(start_indices)]
op = lambda x, starts: lax.dynamic_slice(x, starts, size_indices)
self._CompileAndCheck(op, args_maker, check_dtypes=True)
self._CompileAndCheck(op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_start_indices={}_size_indices={}".format(
@ -1159,8 +1158,7 @@ class LaxTest(jtu.JaxTestCase):
return [rng(shape, dtype), rng(update_shape, dtype),
onp.array(start_indices)]
self._CompileAndCheck(lax.dynamic_update_slice, args_maker,
check_dtypes=True)
self._CompileAndCheck(lax.dynamic_update_slice, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_start_indices={}_update_shape={}".format(
@ -1202,7 +1200,7 @@ class LaxTest(jtu.JaxTestCase):
rng = rng_factory(self.rng())
args_maker = lambda: [rng(shape, dtype)]
op = lambda x: lax.transpose(x, perm)
self._CompileAndCheck(op, args_maker, check_dtypes=True)
self._CompileAndCheck(op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_perm={}".format(
@ -1257,13 +1255,13 @@ class LaxTest(jtu.JaxTestCase):
init_val = onp.asarray(init_val, dtype=dtype)
fun = lambda operand, init_val: lax.reduce(operand, init_val, op, dims)
args_maker = lambda: [rng(shape, dtype), init_val]
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
self._CompileAndCheck(fun, args_maker)
# we separately test the version that uses a concrete init_val because it
# can hit different code paths
fun = lambda operand: lax.reduce(operand, init_val, op, dims)
args_maker = lambda: [rng(shape, dtype)]
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
self._CompileAndCheck(fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_dtype={}_padding={}"
@ -1297,7 +1295,7 @@ class LaxTest(jtu.JaxTestCase):
# pylint: disable=cell-var-from-loop
for shape, dims, strides in all_configs:
args_maker = lambda: [rng(shape, dtype), init_val]
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
self._CompileAndCheck(fun, args_maker)
# pylint: enable=cell-var-from-loop
# we separately test the version that uses a concrete init_val because it
@ -1308,7 +1306,7 @@ class LaxTest(jtu.JaxTestCase):
# pylint: disable=cell-var-from-loop
for shape, dims, strides in all_configs:
args_maker = lambda: [rng(shape, dtype)]
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
self._CompileAndCheck(fun, args_maker)
# pylint: enable=cell-var-from-loop
@parameterized.named_parameters(jtu.cases_from_list(
@ -1329,9 +1327,9 @@ class LaxTest(jtu.JaxTestCase):
def testCumulativeReduce(self, op, onp_op, shape, dtype, axis, rng_factory):
rng = rng_factory(self.rng())
fun = partial(op, axis=axis)
onp_fun = partial(onp_op, axis=axis)
onp_fun = partial(onp_op, axis=axis, dtype=dtype)
args_maker = lambda: [rng(shape, dtype)]
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
self._CompileAndCheck(fun, args_maker)
self._CheckAgainstNumpy(fun, onp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
@ -1350,7 +1348,7 @@ class LaxTest(jtu.JaxTestCase):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
fun = lambda x: lax.sort(x, dimension=axis)
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
self._CompileAndCheck(fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_axis={}".format(
@ -1399,7 +1397,7 @@ class LaxTest(jtu.JaxTestCase):
return keys, values
fun = lambda keys, values: lax.sort_key_val(keys, values, axis)
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
self._CompileAndCheck(fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_keyshape={}_valshape={}_axis={}".format(
@ -1447,12 +1445,12 @@ class LaxTest(jtu.JaxTestCase):
values = self.rng().permutation(flat_values).reshape(shape)
return [values]
def reference_top_k(x):
bcast_idxs = onp.broadcast_to(onp.arange(shape[-1]), shape)
bcast_idxs = onp.broadcast_to(onp.arange(shape[-1], dtype=onp.int32), shape)
sorted_vals, sorted_idxs = lax_reference.sort_key_val(x, bcast_idxs)
return sorted_vals[..., :-k-1:-1], sorted_idxs[..., :-k-1:-1]
op = lambda vs: lax.top_k(vs, k=k)
self._CheckAgainstNumpy(op, reference_top_k, args_maker)
self._CompileAndCheck(op, args_maker, check_dtypes=True)
self._CompileAndCheck(op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_lhs_shape={}_rhs_shape={}"
@ -1468,7 +1466,7 @@ class LaxTest(jtu.JaxTestCase):
def testBatchMatMul(self, lhs_shape, rhs_shape, dtype, rng_factory):
rng = rng_factory(self.rng())
arg_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
self._CompileAndCheck(lax.batch_matmul, arg_maker, check_dtypes=True)
self._CompileAndCheck(lax.batch_matmul, arg_maker)
def testCollapse(self):
@ -1498,7 +1496,7 @@ class LaxTest(jtu.JaxTestCase):
rand_idxs = lambda: tuple(rng(e.shape, e.dtype) for e in idxs)
args_maker = lambda: [rng(shape, dtype), rand_idxs()]
fun = lambda src, idxs: lax.index_take(src, idxs, axes)
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
self._CompileAndCheck(fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}".format(
@ -1531,7 +1529,7 @@ class LaxTest(jtu.JaxTestCase):
rand_idxs = lambda: rng_idx(idxs.shape, idxs.dtype)
args_maker = lambda: [rng(shape, dtype), rand_idxs()]
fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
self._CompileAndCheck(fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
@ -1562,7 +1560,7 @@ class LaxTest(jtu.JaxTestCase):
args_maker = lambda: [rng(arg_shape, dtype), rand_idxs(),
rng(update_shape, dtype)]
fun = partial(lax.scatter_add, dimension_numbers=dnums)
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
self._CompileAndCheck(fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
@ -1593,7 +1591,7 @@ class LaxTest(jtu.JaxTestCase):
args_maker = lambda: [rng(arg_shape, dtype), rand_idxs(),
rng(update_shape, dtype)]
fun = partial(lax.scatter_min, dimension_numbers=dnums)
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
self._CompileAndCheck(fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
@ -1624,7 +1622,7 @@ class LaxTest(jtu.JaxTestCase):
args_maker = lambda: [rng(arg_shape, dtype), rand_idxs(),
rng(update_shape, dtype)]
fun = partial(lax.scatter_max, dimension_numbers=dnums)
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
self._CompileAndCheck(fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
@ -1655,7 +1653,7 @@ class LaxTest(jtu.JaxTestCase):
args_maker = lambda: [rng(arg_shape, dtype), rand_idxs(),
rng(update_shape, dtype)]
fun = partial(lax.scatter, dimension_numbers=dnums)
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
self._CompileAndCheck(fun, args_maker)
def testIssue831(self):
# Tests the DeviceTuple constant handler
@ -1667,7 +1665,7 @@ class LaxTest(jtu.JaxTestCase):
def testReshapeWithUnusualShapes(self):
ans = lax.reshape(onp.ones((3,), onp.float32), (lax.add(1, 2), 1))
self.assertAllClose(ans, onp.ones((3, 1), onp.float32), check_dtypes=True)
self.assertAllClose(ans, onp.ones((3, 1), onp.float32))
self.assertRaisesRegex(
TypeError,
@ -1722,9 +1720,9 @@ class LazyConstantTest(jtu.JaxTestCase):
jit_result = api.jit(lambda x: lax.add(x, make_const()))(zero)
# ensure they're all the same
self.assertAllClose(asarray_result, expected, check_dtypes=True)
self.assertAllClose(argument_result, expected, check_dtypes=True)
self.assertAllClose(jit_result, expected, check_dtypes=True)
self.assertAllClose(asarray_result, expected)
self.assertAllClose(argument_result, expected)
self.assertAllClose(jit_result, expected)
# ensure repr doesn't crash
repr(make_const())
@ -2748,11 +2746,11 @@ class LaxAutodiffTest(jtu.JaxTestCase):
x = 3.14
ans = api.grad(f)(x)
expected = api.grad(f2)(x, x)
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
ans = api.grad(api.grad(f))(x)
expected = api.grad(api.grad(f2))(x, x)
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
ans = api.grad(lambda x: lax.stop_gradient({'foo':x})['foo'])(3.)
expected = onp.array(0.0)
@ -2784,8 +2782,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
# N.B.: intentionally written as 1/x, not x ** -1 or reciprocal(x)
return 1 / x
grad_fn = jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(inv))))))
self.assertAllClose(onp.float32(0.0439453125), grad_fn(onp.float32(4.)),
check_dtypes=True)
self.assertAllClose(onp.float32(0.0439453125), grad_fn(onp.float32(4.)))
def all_bdims(*shapes):
bdims = (itertools.chain([cast(Optional[int], None)],
@ -2819,7 +2816,7 @@ class LaxVmapTest(jtu.JaxTestCase):
args_slice = args_slicer(args, bdims)
ans = api.vmap(op, bdims)(*args)
expected = onp.stack([op(*args_slice(i)) for i in range(bdim_size)])
self.assertAllClose(ans, expected, check_dtypes=True, rtol=rtol, atol=atol)
self.assertAllClose(ans, expected, rtol=rtol, atol=atol)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(

View File

@ -71,8 +71,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
self.skipTest("Unimplemented case for complex Cholesky decomposition.")
self._CheckAgainstNumpy(np.linalg.cholesky, jnp.linalg.cholesky, args_maker,
check_dtypes=True, tol=1e-3)
self._CompileAndCheck(jnp.linalg.cholesky, args_maker, check_dtypes=True)
tol=1e-3)
self._CompileAndCheck(jnp.linalg.cholesky, args_maker)
if jnp.finfo(dtype).bits == 64:
jtu.check_grads(jnp.linalg.cholesky, args_maker(), order=2)
@ -96,14 +96,13 @@ class NumpyLinalgTest(jtu.JaxTestCase):
_skip_if_unsupported_type(dtype)
args_maker = lambda: [rng((n, n), dtype)]
self._CheckAgainstNumpy(np.linalg.det, jnp.linalg.det, args_maker,
check_dtypes=True, tol=1e-3)
self._CompileAndCheck(jnp.linalg.det, args_maker, check_dtypes=True,
self._CheckAgainstNumpy(np.linalg.det, jnp.linalg.det, args_maker, tol=1e-3)
self._CompileAndCheck(jnp.linalg.det, args_maker,
rtol={np.float64: 1e-13, np.complex128: 1e-13})
def testDetOfSingularMatrix(self):
x = jnp.array([[-1., 3./2], [2./3, -1.]], dtype=np.float32)
self.assertAllClose(np.float32(0), jsp.linalg.det(x), check_dtypes=True)
self.assertAllClose(np.float32(0), jsp.linalg.det(x))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -177,10 +176,9 @@ class NumpyLinalgTest(jtu.JaxTestCase):
self._CheckAgainstNumpy(np.linalg.tensorsolve,
jnp.linalg.tensorsolve, args_maker,
check_dtypes=True,
tol={np.float32: 1e-2, np.float64: 1e-3})
self._CompileAndCheck(jnp.linalg.tensorsolve,
args_maker, check_dtypes=True,
args_maker,
rtol={np.float64: 1e-13})
@parameterized.named_parameters(jtu.cases_from_list(
@ -198,8 +196,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np.linalg.slogdet, jnp.linalg.slogdet, args_maker,
check_dtypes=True, tol=1e-3)
self._CompileAndCheck(jnp.linalg.slogdet, args_maker, check_dtypes=True)
tol=1e-3)
self._CompileAndCheck(jnp.linalg.slogdet, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -221,7 +219,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
mat = jnp.array([np.diag(np.ones([5], dtype=np.float32))*(-.01)] * 2)
args_maker = lambda: [mat]
self._CheckAgainstNumpy(np.linalg.slogdet, jnp.linalg.slogdet, args_maker,
check_dtypes=True, tol=1e-3)
tol=1e-3)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(
@ -249,7 +247,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
self.assertTrue(np.all(norm(np.matmul(a, v) - w[..., None, :] * v) < 100))
self._CompileAndCheck(partial(jnp.linalg.eig), args_maker,
check_dtypes=True, rtol=1e-3)
rtol=1e-3)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(
@ -269,7 +267,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
a, = args_maker()
w1, _ = jnp.linalg.eig(a)
w2 = jnp.linalg.eigvals(a)
self.assertAllClose(w1, w2, check_dtypes=True)
self.assertAllClose(w1, w2)
@jtu.skip_on_devices("gpu", "tpu")
def testEigvalsInf(self):
@ -328,7 +326,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
self.assertTrue(norm(np.matmul(a, v) - w * v) < tol)
self._CompileAndCheck(partial(jnp.linalg.eigh, UPLO=uplo), args_maker,
check_dtypes=True, rtol=1e-3)
rtol=1e-3)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(
@ -349,7 +347,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
a = (a + np.conj(a.T)) / 2
return [a]
self._CheckAgainstNumpy(np.linalg.eigvalsh, jnp.linalg.eigvalsh, args_maker,
check_dtypes=True, tol=1e-3)
tol=1e-3)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -478,9 +476,9 @@ class NumpyLinalgTest(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
np_fn = partial(np.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
np_fn = partial(jnp.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
self._CheckAgainstNumpy(np_fn, np_fn, args_maker,
check_dtypes=False, tol=1e-3)
self._CompileAndCheck(np_fn, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fn, np_fn, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(np_fn, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_n={}_full_matrices={}_compute_uv={}".format(
@ -533,7 +531,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
self.assertTrue(np.allclose(np.linalg.svd(a, compute_uv=False), np.asarray(out), atol=1e-4, rtol=1e-4))
self._CompileAndCheck(partial(jnp.linalg.svd, full_matrices=full_matrices, compute_uv=compute_uv),
args_maker, check_dtypes=True)
args_maker)
if not (compute_uv and full_matrices):
svd = partial(jnp.linalg.svd, full_matrices=full_matrices,
compute_uv=compute_uv)
@ -698,8 +696,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
self._CheckAgainstNumpy(np.linalg.solve, jnp.linalg.solve, args_maker,
check_dtypes=True, tol=1e-3)
self._CompileAndCheck(jnp.linalg.solve, args_maker, check_dtypes=True)
tol=1e-3)
self._CompileAndCheck(jnp.linalg.solve, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -726,8 +724,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
return [a]
self._CheckAgainstNumpy(np.linalg.inv, jnp.linalg.inv, args_maker,
check_dtypes=True, tol=1e-3)
self._CompileAndCheck(jnp.linalg.inv, args_maker, check_dtypes=True)
tol=1e-3)
self._CompileAndCheck(jnp.linalg.inv, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -743,8 +741,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np.linalg.pinv, jnp.linalg.pinv, args_maker,
check_dtypes=True, tol=1e-2)
self._CompileAndCheck(jnp.linalg.pinv, args_maker, check_dtypes=True)
tol=1e-2)
self._CompileAndCheck(jnp.linalg.pinv, args_maker)
# TODO(phawkins): 1e-1 seems like a very loose tolerance.
jtu.check_grads(jnp.linalg.pinv, args_maker(), 2, rtol=1e-1, atol=2e-1)
@ -754,15 +752,13 @@ class NumpyLinalgTest(jtu.JaxTestCase):
a = jnp.array([[0., 0.],[-p, 1.]], jnp.float32) * 1 / (1 + p**2)
return jnp.linalg.pinv(a)
j = jax.jacobian(f)(jnp.float32(2.))
self.assertAllClose(jnp.array([[0., -1.], [ 0., 0.]], jnp.float32), j,
check_dtypes=True)
self.assertAllClose(jnp.array([[0., -1.], [ 0., 0.]], jnp.float32), j)
expected = jnp.array([[[[-1., 0.], [ 0., 0.]], [[0., -1.], [0., 0.]]],
[[[0., 0.], [-1., 0.]], [[0., 0.], [0., -1.]]]],
dtype=jnp.float32)
self.assertAllClose(
expected, jax.jacobian(jnp.linalg.pinv)(jnp.eye(2, dtype=jnp.float32)),
check_dtypes=True)
expected, jax.jacobian(jnp.linalg.pinv)(jnp.eye(2, dtype=jnp.float32)))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_n={}".format(
@ -781,9 +777,9 @@ class NumpyLinalgTest(jtu.JaxTestCase):
tol = 1e-1 if jtu.device_under_test() == "tpu" else 1e-3
self._CheckAgainstNumpy(partial(np.linalg.matrix_power, n=n),
partial(jnp.linalg.matrix_power, n=n),
args_maker, check_dtypes=True, tol=tol)
args_maker, tol=tol)
self._CompileAndCheck(partial(jnp.linalg.matrix_power, n=n), args_maker,
check_dtypes=True, rtol=1e-3)
rtol=1e-3)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(
@ -825,9 +821,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
tol = {np.float32: 1e-4, np.float64: 1e-10,
np.complex64: 1e-4, np.complex128: 1e-10}
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True,
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True,
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol)
self._CompileAndCheck(jnp_fun, args_maker,
atol=tol, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
@ -865,7 +860,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
return [lhs, rng(rhs_shape, dtype)]
self._CheckAgainstNumpy(onp_fun, jnp_fun_numpy_resid, args_maker, check_dtypes=False, tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, atol=tol, rtol=tol)
self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol)
# Disabled because grad is flaky for low-rank inputs.
# TODO:
@ -880,7 +875,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
grad_test_jc = jit(grad(jit(test)))
xc = np.eye(3, dtype=np.complex)
self.assertAllClose(xc, grad_test_jc(xc), check_dtypes=True)
self.assertAllClose(xc, grad_test_jc(xc))
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testIssue1151(self):
@ -888,8 +883,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
A = jnp.array(rng.randn(100, 3, 3), dtype=jnp.float32)
b = jnp.array(rng.randn(100, 3), dtype=jnp.float32)
x = jnp.linalg.solve(A, b)
self.assertAllClose(vmap(jnp.dot)(A, x), b, atol=2e-3, rtol=1e-2,
check_dtypes=True)
self.assertAllClose(vmap(jnp.dot)(A, x), b, atol=2e-3, rtol=1e-2)
jac0 = jax.jacobian(jnp.linalg.solve, argnums=0)(A, b)
jac1 = jax.jacobian(jnp.linalg.solve, argnums=1)(A, b)
jac0 = jax.jacobian(jnp.linalg.solve, argnums=0)(A[0], b[0])
@ -926,8 +920,8 @@ class ScipyLinalgTest(jtu.JaxTestCase):
def testBlockDiag(self, args):
args_maker = lambda: args
self._CheckAgainstNumpy(osp.linalg.block_diag, jsp.linalg.block_diag,
args_maker, check_dtypes=True)
self._CompileAndCheck(jsp.linalg.block_diag, args_maker, check_dtypes=True)
args_maker)
self._CompileAndCheck(jsp.linalg.block_diag, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
@ -943,15 +937,15 @@ class ScipyLinalgTest(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
x, = args_maker()
p, l, u = jsp.linalg.lu(x)
self.assertAllClose(x, np.matmul(p, np.matmul(l, u)), check_dtypes=True,
self.assertAllClose(x, np.matmul(p, np.matmul(l, u)),
rtol={np.float32: 1e-3, np.float64: 1e-12,
np.complex64: 1e-3, np.complex128: 1e-12})
self._CompileAndCheck(jsp.linalg.lu, args_maker, check_dtypes=True)
self._CompileAndCheck(jsp.linalg.lu, args_maker)
def testLuOfSingularMatrix(self):
x = jnp.array([[-1., 3./2], [2./3, -1.]], dtype=np.float32)
p, l, u = jsp.linalg.lu(x)
self.assertAllClose(x, np.matmul(p, np.matmul(l, u)), check_dtypes=True)
self.assertAllClose(x, np.matmul(p, np.matmul(l, u)))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -986,9 +980,9 @@ class ScipyLinalgTest(jtu.JaxTestCase):
us = np.stack([out[2] for out in expected])
actual_ps, actual_ls, actual_us = vmap(jsp.linalg.lu)(jnp.stack(args))
self.assertAllClose(ps, actual_ps, check_dtypes=True)
self.assertAllClose(ls, actual_ls, check_dtypes=True)
self.assertAllClose(us, actual_us, check_dtypes=True)
self.assertAllClose(ps, actual_ps)
self.assertAllClose(ls, actual_ls)
self.assertAllClose(us, actual_us)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -1008,9 +1002,9 @@ class ScipyLinalgTest(jtu.JaxTestCase):
u = np.triu(lu)
for i in range(n):
x[[i, piv[i]],] = x[[piv[i], i],]
self.assertAllClose(x, np.matmul(l, u), check_dtypes=True, rtol=1e-3,
self.assertAllClose(x, np.matmul(l, u), rtol=1e-3,
atol=1e-3)
self._CompileAndCheck(jsp.linalg.lu_factor, args_maker, check_dtypes=True)
self._CompileAndCheck(jsp.linalg.lu_factor, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -1040,9 +1034,8 @@ class ScipyLinalgTest(jtu.JaxTestCase):
lu, piv = osp.linalg.lu_factor(a)
return [lu, piv, rng(rhs_shape, dtype)]
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker,
check_dtypes=True, tol=1e-3)
self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=1e-3)
self._CompileAndCheck(jsp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -1081,9 +1074,8 @@ class ScipyLinalgTest(jtu.JaxTestCase):
a = np.tril(a) if lower else np.triu(a)
return [a, rng(rhs_shape, dtype)]
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker,
check_dtypes=True, tol=1e-3)
self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=1e-3)
self._CompileAndCheck(jsp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -1132,7 +1124,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
l if lower else T(l), b, trans=1 if transpose_a else 0, lower=lower,
unit_diagonal=unit_diagonal)
self.assertAllClose(np_ans, ans, check_dtypes=True,
self.assertAllClose(np_ans, ans,
rtol={np.float32: 1e-4, np.float64: 1e-11})
@parameterized.named_parameters(jtu.cases_from_list(
@ -1228,15 +1220,13 @@ class ScipyLinalgTest(jtu.JaxTestCase):
osp_fun = lambda a: osp.linalg.expm(a)
jsp_fun = lambda a: jsp.linalg.expm(a)
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker,
check_dtypes=True)
self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker)
self._CompileAndCheck(jsp_fun, args_maker)
args_maker_triu = lambda: [np.triu(rng((n, n), dtype))]
jsp_fun_triu = lambda a: jsp.linalg.expm(a,upper_triangular=True)
self._CheckAgainstNumpy(osp_fun, jsp_fun_triu, args_maker_triu,
check_dtypes=True)
self._CompileAndCheck(jsp_fun_triu, args_maker_triu, check_dtypes=True)
jsp_fun_triu = lambda a: jsp.linalg.expm(a, upper_triangular=True)
self._CheckAgainstNumpy(osp_fun, jsp_fun_triu, args_maker_triu)
self._CompileAndCheck(jsp_fun_triu, args_maker_triu)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -1249,9 +1239,8 @@ class ScipyLinalgTest(jtu.JaxTestCase):
args_maker_zeros = lambda: [np.zeros((n, n), dtype)]
osp_fun = lambda a: osp.linalg.expm(a)
jsp_fun = lambda a: jsp.linalg.expm(a)
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker_zeros,
check_dtypes=True)
self._CompileAndCheck(jsp_fun, args_maker_zeros, check_dtypes=True)
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker_zeros)
self._CompileAndCheck(jsp_fun, args_maker_zeros)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_lhs={}_rhs={}_lower={}".format(
@ -1280,7 +1269,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
U = np.triu(rng(lhs_shape, dtype))
return [(U, lower), b]
self._CheckAgainstNumpy(osp.linalg.cho_solve, jsp.linalg.cho_solve,
args_maker, check_dtypes=True, tol=1e-3)
args_maker, tol=1e-3)
if __name__ == "__main__":

View File

@ -35,7 +35,7 @@ class LoopsTest(jtu.JaxTestCase):
with loops.Scope() as s:
s.x = r + 1
return s.x
self.assertAllClose(4.0, f_op(3.), check_dtypes=True)
self.assertAllClose(4.0, f_op(3.))
def test_loop_empty(self):
def f_op(r):
@ -44,7 +44,7 @@ class LoopsTest(jtu.JaxTestCase):
pass
return r
self.assertAllClose(3.0, f_op(3.), check_dtypes=True)
self.assertAllClose(3.0, f_op(3.))
def test_loop_1(self):
"""One loop with one state var, with transforms."""
@ -56,14 +56,14 @@ class LoopsTest(jtu.JaxTestCase):
return s.out
def f_expected(inc):
return 10 + 5 * inc
self.assertAllClose(f_expected(2.), f_op(2.), check_dtypes=True)
self.assertAllClose(f_expected(2.), api.jit(f_op)(2.), check_dtypes=True)
self.assertAllClose(5., api.grad(f_op)(2.), check_dtypes=True)
self.assertAllClose(5., api.grad(f_op)(2.), check_dtypes=True)
self.assertAllClose(f_expected(2.), f_op(2.))
self.assertAllClose(f_expected(2.), api.jit(f_op)(2.))
self.assertAllClose(5., api.grad(f_op)(2.))
self.assertAllClose(5., api.grad(f_op)(2.))
inc_batch = np.arange(5, dtype=jnp.float_)
self.assertAllClose(jnp.array([f_expected(inc) for inc in inc_batch],
dtype=jnp.float_),
api.vmap(f_op)(inc_batch), check_dtypes=True)
api.vmap(f_op)(inc_batch))
def test_loop_2(self):
@ -77,7 +77,7 @@ class LoopsTest(jtu.JaxTestCase):
s.out2 += 1.
return (s.out1, s.out2)
self.assertAllClose((10. + 2. * 5, 20. + 1. * 5), f_op(2.), check_dtypes=True)
self.assertAllClose((10. + 2. * 5, 20. + 1. * 5), f_op(2.))
def test_add_vectors(self):
@ -92,7 +92,7 @@ class LoopsTest(jtu.JaxTestCase):
x = jnp.array([1., 2., 3.], dtype=jnp.float32)
y = jnp.array([4., 5., 6.], dtype=jnp.float32)
self.assertAllClose(jnp.add(x, y), add_vec(x, y), check_dtypes=True)
self.assertAllClose(jnp.add(x, y), add_vec(x, y))
def test_matmul(self):
def matmul(x, y):
@ -109,7 +109,7 @@ class LoopsTest(jtu.JaxTestCase):
x = jnp.array([[1., 2., 3.]], dtype=jnp.float32) # 1x3
y = jnp.array([[4.], [5.], [6.]], dtype=jnp.float32) # 3x1
self.assertAllClose(jnp.matmul(x, y), matmul(x, y), check_dtypes=True)
self.assertAllClose(jnp.matmul(x, y), matmul(x, y))
def test_reuse_range(self):
"""Ranges can be reused, as long as not nested in each other."""
@ -136,7 +136,7 @@ class LoopsTest(jtu.JaxTestCase):
s.out += inc
return s.out
self.assertAllClose(10. + 5 * (2. + 6 * 2.), f_op(2.), check_dtypes=True)
self.assertAllClose(10. + 5 * (2. + 6 * 2.), f_op(2.))
def test_example_doc(self):
"The example from the module docstring."
@ -169,8 +169,8 @@ class LoopsTest(jtu.JaxTestCase):
s.arr = ops.index_update(s.arr, i, s.arr[i] + 1.)
return s.arr
self.assertAllClose(f_expected(), f_op_jax(), check_dtypes=True)
self.assertAllClose(f_expected(), f_op_loops(), check_dtypes=True)
self.assertAllClose(f_expected(), f_op_jax())
self.assertAllClose(f_expected(), f_op_loops())
def test_loop_mutable_used_but_not_changed(self):
def f_op(inc):
@ -184,7 +184,7 @@ class LoopsTest(jtu.JaxTestCase):
return save_to_other_var
self.assertAllClose(10. + 5 * 2., f_op(2.), check_dtypes=True)
self.assertAllClose(10. + 5 * 2., f_op(2.))
def test_range_locations(self):
"""Ranges have locations."""
@ -257,7 +257,7 @@ class LoopsTest(jtu.JaxTestCase):
pass
return i
self.assertAllClose(4, f_op(4), check_dtypes=True)
self.assertAllClose(4, f_op(4))
def test_error_new_state_in_loop(self):
"""Error when creating new state in a loop."""
@ -281,13 +281,13 @@ class LoopsTest(jtu.JaxTestCase):
s.out += inc
return s.out
self.assertAllClose(16., f_op(0, 4, 4.), check_dtypes=True)
self.assertAllClose(16., f_op(0, 4, 4.))
# Ok to jit, as long as the start and end are static
self.assertAllClose(16., api.jit(f_op, static_argnums=(0, 1))(0, 4, 4.), check_dtypes=True)
self.assertAllClose(16., api.jit(f_op, static_argnums=(0, 1))(0, 4, 4.))
with self.assertRaisesRegex(TypeError, "Abstract tracer value encountered where concrete value is expected"):
self.assertAllClose(16., api.jit(f_op)(0, 4, 4.), check_dtypes=True)
self.assertAllClose(16., api.jit(f_op)(0, 4, 4.))
with self.assertRaisesRegex(TypeError, "Abstract tracer value encountered where concrete value is expected"):
self.assertAllClose(16., api.vmap(f_op)(jnp.zeros(10), jnp.ones(10), jnp.array([4.] * 10)), check_dtypes=True)
self.assertAllClose(16., api.vmap(f_op)(jnp.zeros(10), jnp.ones(10), jnp.array([4.] * 10)))
def test_cond(self):
def f_op(inc):
@ -297,8 +297,8 @@ class LoopsTest(jtu.JaxTestCase):
s.out += inc
return s.out
self.assertAllClose(10. + 2., f_op(2.), check_dtypes=True)
self.assertAllClose(10., f_op(-2.), check_dtypes=True)
self.assertAllClose(10. + 2., f_op(2.))
self.assertAllClose(10., f_op(-2.))
def test_cond_state(self):
"""Conditionals predicated on scope fields."""
@ -309,8 +309,8 @@ class LoopsTest(jtu.JaxTestCase):
s.out *= 2.
return s.out
self.assertAllClose(2. * 2., f_op(2.), check_dtypes=True)
self.assertAllClose(-2., f_op(-2.), check_dtypes=True)
self.assertAllClose(2. * 2., f_op(2.))
self.assertAllClose(-2., f_op(-2.))
def test_cond_nested(self):
"""Nested conditionals."""
@ -341,7 +341,7 @@ class LoopsTest(jtu.JaxTestCase):
return s.out
for init in [-1., 0., 9., 10.]:
self.assertAllClose(f_expected(init), f_op(init), check_dtypes=True)
self.assertAllClose(f_expected(init), f_op(init))
def test_error_cond_using_index_var(self):
@ -373,13 +373,13 @@ class LoopsTest(jtu.JaxTestCase):
out += 1.
return out
self.assertAllClose(f_expected(2.), f_op(2.), check_dtypes=True)
self.assertAllClose(f_expected(2.), api.jit(f_op)(2.), check_dtypes=True)
self.assertAllClose(f_expected(1.), f_op(1.), check_dtypes=True)
self.assertAllClose(f_expected(2.), f_op(2.))
self.assertAllClose(f_expected(2.), api.jit(f_op)(2.))
self.assertAllClose(f_expected(1.), f_op(1.))
init_batch = np.array([1., 2., 3.], dtype=np.float32)
self.assertAllClose(np.array([f_expected(init) for init in init_batch],
dtype=np.float32),
api.vmap(f_op)(init_batch), check_dtypes=True)
api.vmap(f_op)(init_batch))
def test_error_while_cond_mutation(self):
"""Disallow mutation in the while conditional."""

View File

@ -601,8 +601,8 @@ class MaskingTest(jtu.JaxTestCase):
def test_slice_oob_indexing(self):
# https://github.com/google/jax/issues/2245
self.assertAllClose(jnp.ones(5), jnp.ones(5)[:10], check_dtypes=True)
self.assertAllClose(jnp.ones(5), jnp.ones(5)[-10:], check_dtypes=True)
self.assertAllClose(jnp.ones(5), jnp.ones(5)[:10])
self.assertAllClose(jnp.ones(5), jnp.ones(5)[-10:])
if __name__ == '__main__':
absltest.main()

View File

@ -51,7 +51,7 @@ class MultiBackendTest(jtu.JaxTestCase):
y = npr.uniform(size=(10,10))
z_host = np.matmul(x, y)
z = fun(x, y)
self.assertAllClose(z, z_host, check_dtypes=True, rtol=1e-2)
self.assertAllClose(z, z_host, rtol=1e-2)
correct_platform = backend if backend else jtu.device_under_test()
self.assertEqual(z.device_buffer.platform(), correct_platform)
@ -73,7 +73,7 @@ class MultiBackendTest(jtu.JaxTestCase):
y = npr.uniform(size=(10,10))
z_host = np.matmul(x, y) + np.ones_like(x)
z = fun(x, y)
self.assertAllClose(z, z_host, check_dtypes=True, rtol=1e-2)
self.assertAllClose(z, z_host, rtol=1e-2)
correct_platform = outer if outer else jtu.device_under_test()
self.assertEqual(z.device_buffer.platform(), correct_platform)

View File

@ -55,7 +55,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
def testSoftplusGradInf(self):
self.assertAllClose(
1., jax.grad(nn.softplus)(float('inf')), check_dtypes=True)
1., jax.grad(nn.softplus)(float('inf')))
def testSoftplusGradNegInf(self):
check_grads(nn.softplus, (-float('inf'),), order=1,
@ -91,7 +91,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
def testGluValue(self):
val = nn.glu(jnp.array([1.0, 0.0]))
self.assertAllClose(val, jnp.array([0.5]), check_dtypes=True)
self.assertAllClose(val, jnp.array([0.5]))
@parameterized.parameters(*itertools.product(
(jnp.float32, jnp.bfloat16, jnp.float16),
@ -120,33 +120,33 @@ class NNFunctionsTest(jtu.JaxTestCase):
expected = jnp.array([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]])
self.assertAllClose(actual, expected, check_dtypes=True)
self.assertAllClose(actual, expected)
actual = nn.one_hot(jnp.array([1, 2, 0]), 3)
expected = jnp.array([[0., 1., 0.],
[0., 0., 1.],
[1., 0., 0.]])
self.assertAllClose(actual, expected, check_dtypes=True)
self.assertAllClose(actual, expected)
def testOneHotOutOfBound(self):
actual = nn.one_hot(jnp.array([-1, 3]), 3)
expected = jnp.array([[0., 0., 0.],
[0., 0., 0.]])
self.assertAllClose(actual, expected, check_dtypes=True)
self.assertAllClose(actual, expected)
def testOneHotNonArrayInput(self):
actual = nn.one_hot([0, 1, 2], 3)
expected = jnp.array([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]])
self.assertAllClose(actual, expected, check_dtypes=True)
self.assertAllClose(actual, expected)
def testOneHotCustomDtype(self):
actual = nn.one_hot(jnp.array([0, 1, 2]), 3, dtype=jnp.bool_)
expected = jnp.array([[True, False, False],
[False, True, False],
[False, False, True]])
self.assertAllClose(actual, expected, check_dtypes=True)
self.assertAllClose(actual, expected)
InitializerRecord = collections.namedtuple(
"InitializerRecord",

View File

@ -39,7 +39,7 @@ class OptimizerTests(jtu.JaxTestCase):
def _CheckFuns(self, optimizer, loss, x0, *args):
init_fun, update_fun, get_params = optimizer(*args)
opt_state = init_fun(x0)
self.assertAllClose(x0, get_params(opt_state), check_dtypes=True)
self.assertAllClose(x0, get_params(opt_state))
opt_state2 = update_fun(0, grad(loss)(x0), opt_state) # doesn't crash
self.assertEqual(tree_util.tree_structure(opt_state),
tree_util.tree_structure(opt_state2))
@ -294,7 +294,7 @@ class OptimizerTests(jtu.JaxTestCase):
J1 = jacrev(loss, argnums=(0,))(initial_params)
J2 = jacfwd(loss, argnums=(0,))(initial_params)
self.assertAllClose(J1, J2, check_dtypes=True, rtol=1e-6)
self.assertAllClose(J1, J2, rtol=1e-6)
def testUnpackPackRoundTrip(self):
opt_init, _, _ = optimizers.momentum(0.1, mass=0.9)

View File

@ -87,7 +87,7 @@ class PapplyTest(jtu.JaxTestCase):
t = np.ones((5, 3))
ans = soft_pmap(*_papply(fun))(t)
expected = fun(t)
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
def testLogSoftmax(self):
raise SkipTest("test doesn't pass yet") # TODO(frostig)
@ -113,7 +113,7 @@ class PapplyTest(jtu.JaxTestCase):
pfun, axis_name = _papply(jnp.add)
ans = soft_pmap(pfun, axis_name)(x, x)
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
def testAddBroadcasting(self):
raise SkipTest("test doesn't pass yet") # TODO(frostig)
@ -126,7 +126,7 @@ class PapplyTest(jtu.JaxTestCase):
pfun, axis_name = _papply(fun)
ans = soft_pmap(pfun, axis_name)(x)
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
def testMakeJaxprPapplyComposition(self):
raise SkipTest( # TODO(mattjj)
@ -142,7 +142,7 @@ class ParallelizeTest(jtu.JaxTestCase):
def dedup(self, arr, expected_rank):
if arr.ndim == expected_rank + 1:
for i in range(arr.shape[0] - 1):
self.assertAllClose(arr[i], arr[i + 1], check_dtypes=True)
self.assertAllClose(arr[i], arr[i + 1])
return arr[0]
else:
assert arr.ndim == expected_rank

View File

@ -217,7 +217,7 @@ class PmapTest(jtu.JaxTestCase):
f_expected = np.broadcast_to(x, mesh_shape)
f_ans = f(x, y)
self.assertAllClose(f_ans, f_expected, check_dtypes=True)
self.assertAllClose(f_ans, f_expected)
self.assertIsInstance(f_ans, pxla.ShardedDeviceArray)
# the output is actually replicated (has the same values in each device buffer)
# but out_axes is implicitly 0, so we shouldn't have replication in the
@ -226,7 +226,7 @@ class PmapTest(jtu.JaxTestCase):
g_expected = np.broadcast_to(x - np.sum(y, 0, keepdims=True), shape)
g_ans = g(x, y)
self.assertAllClose(g_ans, g_expected, check_dtypes=True)
self.assertAllClose(g_ans, g_expected)
self.assertIsInstance(g_ans, pxla.ShardedDeviceArray)
self.assertEqual(g_ans.sharding_spec.replication_factor, 1)
@ -302,7 +302,7 @@ class PmapTest(jtu.JaxTestCase):
ans = grad(lambda x: jnp.sum(splitjvp(x)))(x)
expected = grad(fun)(x)
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
def testTwoArgsGrad(self):
def f(x, y):
@ -352,7 +352,7 @@ class PmapTest(jtu.JaxTestCase):
ans = grad(lambda x: jnp.sum(test_fun(x)))(x)
expected = grad(lambda x: jnp.sum(baseline_fun(x)))(x)
self.assertAllClose(ans, expected, check_dtypes=True, atol=1e-3)
self.assertAllClose(ans, expected, atol=1e-3)
def testShardedDeviceArrays(self):
f = lambda x: 2 * x
@ -466,11 +466,11 @@ class PmapTest(jtu.JaxTestCase):
expected_psum = np.concatenate([expected_psum_1, expected_psum_2], 1)
expected = x - expected_psum
ans = f1(x)
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
expected = x - expected_psum + 1.
ans = f2(x)
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
shape = (replicas // 2, 2, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
@ -482,7 +482,7 @@ class PmapTest(jtu.JaxTestCase):
expected_psum = np.concatenate([expected_psum_1, expected_psum_2], 0)
expected = x - expected_psum
ans = f3(x)
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
def testAxisGroups(self):
axis_env = xla.AxisEnv(8, ('i', 'j'), (4, 2))
@ -568,7 +568,7 @@ class PmapTest(jtu.JaxTestCase):
lambda x: lax.ppermute(x, "i", zip(range(num_devices), perm)), "i")
result = f(jnp.arange(num_devices, dtype=jnp.float32))
expected = jnp.asarray(perm, dtype=jnp.float32)
self.assertAllClose(result, expected, check_dtypes=True)
self.assertAllClose(result, expected)
@jtu.skip_on_devices("cpu", "gpu")
def testRule30(self):
@ -911,7 +911,7 @@ class PmapTest(jtu.JaxTestCase):
arr = pxla.ShardedDeviceArray(aval, sharding_spec, bufs)
r = pmap(lambda x: x + 1)(arr)
self.assertAllClose(r, arr + 1, check_dtypes=True)
self.assertAllClose(r, arr + 1)
self.assertEqual(len(r.device_buffers), 6)
@ignore_soft_pmap_warning()
@ -1114,8 +1114,7 @@ class PmapTest(jtu.JaxTestCase):
vals = list(range(500))
ndevices = xla_bridge.device_count()
self.assertAllClose(f(jnp.array([vals] * ndevices)),
jnp.array([sum(vals)] * ndevices),
check_dtypes=True)
jnp.array([sum(vals)] * ndevices))
def testPostProcessMap(self):
# code from https://github.com/google/jax/issues/2787
@ -1221,7 +1220,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.sum(x, 0)
ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
def testOneDevice(self):
if xla_bridge.device_count() == 1:
@ -1236,8 +1235,8 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
r0 = f0(x)
r1 = f1(x)
expected = np.expand_dims(np.dot(x.squeeze(), x.squeeze().T), 0)
self.assertAllClose(r0, expected, check_dtypes=True, atol=1e-6, rtol=1e-3)
self.assertAllClose(r1, expected, check_dtypes=True, atol=1e-6, rtol=1e-3)
self.assertAllClose(r0, expected, atol=1e-6, rtol=1e-3)
self.assertAllClose(r1, expected, atol=1e-6, rtol=1e-3)
def testNoDevicesError(self):
f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i', devices=[])
@ -1303,7 +1302,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
ndevices = xla_bridge.device_count()
ans = foo(jnp.ones((ndevices, 1)))
expected = np.ones((ndevices, 1), dtype=jnp.float_) * ndevices * 2
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
def testPmapInJit(self):
@jit
@ -1316,7 +1315,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
ndevices = xla_bridge.device_count()
ans = foo(jnp.ones((ndevices, 1)))
expected = np.ones((ndevices, 1), dtype=jnp.float_) * ndevices
self.assertAllClose(ans, expected, check_dtypes=True)
self.assertAllClose(ans, expected)
def testGradBasic(self):
@partial(pmap, axis_name='i', devices=xla_bridge.devices())

View File

@ -144,15 +144,15 @@ class LaxRandomTest(jtu.JaxTestCase):
if jtu.device_under_test() != "tpu":
bits8 = random._random_bits(key, 8, (3,))
expected8 = np.array([216, 115, 43], dtype=np.uint8)
self.assertArraysEqual(bits8, expected8, check_dtypes=True)
self.assertArraysEqual(bits8, expected8)
bits16 = random._random_bits(key, 16, (3,))
expected16 = np.array([41682, 1300, 55017], dtype=np.uint16)
self.assertArraysEqual(bits16, expected16, check_dtypes=True)
self.assertArraysEqual(bits16, expected16)
bits32 = random._random_bits(key, 32, (3,))
expected32 = np.array([56197195, 4200222568, 961309823], dtype=np.uint32)
self.assertArraysEqual(bits32, expected32, check_dtypes=True)
self.assertArraysEqual(bits32, expected32)
bits64 = random._random_bits(key, 64, (3,))
if FLAGS.jax_enable_x64:
@ -160,7 +160,7 @@ class LaxRandomTest(jtu.JaxTestCase):
7882654074788531506], dtype=np.uint64)
else:
expected64 = np.array([676898860, 3164047411, 4010691890], dtype=np.uint32)
self.assertArraysEqual(bits64, expected64, check_dtypes=True)
self.assertArraysEqual(bits64, expected64)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype)}
@ -227,7 +227,7 @@ class LaxRandomTest(jtu.JaxTestCase):
with self.assertWarns(FutureWarning):
perm2 = crand(key)
self.assertAllClose(perm1, perm2, check_dtypes=True)
self.assertAllClose(perm1, perm2)
self.assertFalse(np.all(perm1 == x)) # seems unlikely!
self.assertAllClose(np.sort(perm1), x, check_dtypes=False)
@ -245,12 +245,11 @@ class LaxRandomTest(jtu.JaxTestCase):
perm1 = rand(key)
perm2 = crand(key)
self.assertAllClose(perm1, perm2, check_dtypes=True)
self.assertAllClose(perm1, perm2)
self.assertFalse(np.all(perm1 == x)) # seems unlikely!
self.assertAllClose(np.sort(perm1.ravel()), x.ravel(), check_dtypes=False)
self.assertArraysAllClose(
x, jnp.arange(np.prod(shape)).reshape(shape).astype(dtype),
check_dtypes=True)
x, jnp.arange(np.prod(shape)).reshape(shape).astype(dtype))
def testPermutationInteger(self):
key = random.PRNGKey(0)
@ -261,7 +260,7 @@ class LaxRandomTest(jtu.JaxTestCase):
perm1 = rand(key)
perm2 = crand(key)
self.assertAllClose(perm1, perm2, check_dtypes=True)
self.assertAllClose(perm1, perm2)
self.assertEqual(perm1.dtype, perm2.dtype)
self.assertFalse(np.all(perm1 == np.arange(100))) # seems unlikely!
self.assertAllClose(np.sort(perm1), np.arange(100), check_dtypes=False)
@ -380,7 +379,7 @@ class LaxRandomTest(jtu.JaxTestCase):
compiled_samples = crand(key, alpha)
for samples in [uncompiled_samples, compiled_samples]:
self.assertAllClose(samples.sum(-1), np.ones(10000, dtype=dtype), check_dtypes=True)
self.assertAllClose(samples.sum(-1), np.ones(10000, dtype=dtype))
alpha_sum = sum(alpha)
for i, a in enumerate(alpha):
self._CheckKolmogorovSmirnovCDF(samples[..., i], scipy.stats.beta(a, alpha_sum - a).cdf)
@ -435,7 +434,7 @@ class LaxRandomTest(jtu.JaxTestCase):
pdf = scipy.stats.gamma.pdf(z, alpha)
expected_grad = -cdf_dot / pdf
self.assertAllClose(actual_grad, expected_grad, check_dtypes=True,
self.assertAllClose(actual_grad, expected_grad,
rtol=2e-2 if jtu.device_under_test() == "tpu" else 5e-4)
def testGammaGradType(self):
@ -670,28 +669,24 @@ class LaxRandomTest(jtu.JaxTestCase):
random.randint(k, (3, 3), 0, 8),
np.array([[7, 2, 6],
[2, 1, 0],
[6, 7, 7]], dtype='int64'),
check_dtypes=True)
[6, 7, 7]], dtype='int64'))
else:
self.assertAllClose(
random.randint(k, (3, 3), 0, 8),
np.array([[2, 1, 3],
[6, 1, 5],
[6, 3, 4]], dtype='int32'),
check_dtypes=True)
[6, 3, 4]], dtype='int32'))
self.assertAllClose(
random.split(k, 4),
np.array([[2285895361, 1501764800],
[1518642379, 4090693311],
[ 433833334, 4221794875],
[ 839183663, 3740430601]], dtype='uint32'),
check_dtypes=True)
[ 839183663, 3740430601]], dtype='uint32'))
self.assertAllClose(
random.fold_in(k, 4),
np.array([2285895361, 433833334], dtype='uint32'),
check_dtypes=True)
np.array([2285895361, 433833334], dtype='uint32'))
if __name__ == "__main__":

View File

@ -103,11 +103,9 @@ class NdimageTest(jtu.JaxTestCase):
if dtype in float_dtypes:
epsilon = max([dtypes.finfo(dtypes.canonicalize_dtype(d)).eps
for d in [dtype, coords_dtype]])
self._CheckAgainstNumpy(lsp_op, osp_op, args_maker, tol=100*epsilon,
check_dtypes=True)
self._CheckAgainstNumpy(lsp_op, osp_op, args_maker, tol=100*epsilon)
else:
self._CheckAgainstNumpy(lsp_op, osp_op, args_maker, tol=0,
check_dtypes=True)
self._CheckAgainstNumpy(lsp_op, osp_op, args_maker, tol=0)
def testMapCoordinatesErrors(self):
x = onp.arange(5.0)
@ -137,7 +135,7 @@ class NdimageTest(jtu.JaxTestCase):
lsp_op = lambda x, c: lsp_ndimage.map_coordinates(x, c, order=order)
osp_op = lambda x, c: osp_ndimage.map_coordinates(x, c, order=order)
self._CheckAgainstNumpy(lsp_op, osp_op, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(lsp_op, osp_op, args_maker)
def testContinuousGradients(self):
# regression test for https://github.com/google/jax/issues/3024

View File

@ -64,7 +64,7 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
tol = {onp.float16: 1e-2, onp.float32: 1e-2, onp.float64: 1e-8}
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False, tol=tol)
self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jsp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "op={}_xshape=[{}]_yshape=[{}]_mode={}".format(
@ -87,7 +87,7 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
tol = {onp.float16: 1e-2, onp.float32: 1e-2, onp.float64: 1e-14}
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False, tol=tol)
self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jsp_fun, args_maker)
if __name__ == "__main__":

View File

@ -61,8 +61,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True,
rtol={onp.float64: 1e-14})
self._CompileAndCheck(lax_fun, args_maker, rtol={onp.float64: 1e-14})
@genNamedParametersNArgs(3, jtu.rand_default)
def testPoissonPmf(self, rng_factory, shapes, dtypes):
@ -80,7 +79,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3, jtu.rand_default)
def testBernoulliLogPmf(self, rng_factory, shapes, dtypes):
@ -97,7 +96,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3, jtu.rand_default)
def testGeomLogPmf(self, rng_factory, shapes, dtypes):
@ -114,7 +113,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(5, jtu.rand_positive)
def testBetaLogPdf(self, rng_factory, shapes, dtypes):
@ -128,7 +127,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True,
self._CompileAndCheck(lax_fun, args_maker,
rtol={onp.float32: 2e-3, onp.float64: 1e-4})
@genNamedParametersNArgs(3, jtu.rand_default)
@ -145,7 +144,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(2, jtu.rand_positive)
def testDirichletLogPdf(self, rng_factory, shapes, dtypes):
@ -162,7 +161,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3, jtu.rand_positive)
def testExponLogPdf(self, rng_factory, shapes, dtypes):
@ -176,7 +175,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4, jtu.rand_positive)
def testGammaLogPdf(self, rng_factory, shapes, dtypes):
@ -190,7 +189,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3, jtu.rand_positive)
def testLaplaceLogPdf(self, rng_factory, shapes, dtypes):
@ -206,7 +205,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3, jtu.rand_default)
def testLaplaceCdf(self, rng_factory, shapes, dtypes):
@ -222,7 +221,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol={onp.float32: 1e-5, onp.float64: 1e-6})
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(1, jtu.rand_default)
def testLogisticCdf(self, rng_factory, shapes, dtypes):
@ -235,7 +234,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-6)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(1, jtu.rand_default)
def testLogisticLogpdf(self, rng_factory, shapes, dtypes):
@ -248,7 +247,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(1, jtu.rand_default)
def testLogisticPpf(self, rng_factory, shapes, dtypes):
@ -261,7 +260,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(1, jtu.rand_default)
def testLogisticSf(self, rng_factory, shapes, dtypes):
@ -274,7 +273,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-6)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3, jtu.rand_default)
def testNormLogPdf(self, rng_factory, shapes, dtypes):
@ -290,7 +289,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3, jtu.rand_default)
@ -307,7 +306,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3, jtu.rand_default)
@ -324,7 +323,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-6)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3, jtu.rand_default)
@ -341,9 +340,8 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
return [q, loc, scale]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True, rtol=3e-4)
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
@genNamedParametersNArgs(4, jtu.rand_positive)
@ -358,7 +356,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4, jtu.rand_default)
@ -375,7 +373,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3, jtu.rand_default)
@ -390,7 +388,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lax_fun, args_maker)
def testIssue972(self):
self.assertAllClose(
@ -455,9 +453,9 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(osp_stats.multivariate_normal.logpdf,
lsp_stats.multivariate_normal.logpdf,
args_maker, check_dtypes=True, tol=1e-3)
args_maker, tol=1e-3)
self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf, args_maker,
check_dtypes=True, rtol=1e-4, atol=1e-4)
rtol=1e-4, atol=1e-4)
if __name__ == "__main__":

View File

@ -127,18 +127,18 @@ class VectorizeTest(jtu.JaxTestCase):
b, a = center(jnp.arange(3))
self.assertEqual(a.shape, (3,))
self.assertEqual(b.shape, ())
self.assertAllClose(1.0, b, False)
self.assertAllClose(1.0, b, check_dtypes=False)
X = jnp.arange(12).reshape((3, 4))
b, a = center(X, axis=1)
self.assertEqual(a.shape, (3, 4))
self.assertEqual(b.shape, (3,))
self.assertAllClose(jnp.mean(X, axis=1), b, True)
self.assertAllClose(jnp.mean(X, axis=1), b)
b, a = center(X, axis=0)
self.assertEqual(a.shape, (3, 4))
self.assertEqual(b.shape, (4,))
self.assertAllClose(jnp.mean(X, axis=0), b, True)
self.assertAllClose(jnp.mean(X, axis=0), b)
if __name__ == "__main__":