mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_… (#3280)
* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs. Default to check_dtypes=True. Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense. No functional changes intended. * Fix a number of lax reference implementations to preserve types.
This commit is contained in:
parent
49a441f745
commit
fffdb2daa8
@ -32,7 +32,7 @@ neg = np.negative
|
||||
sign = np.sign
|
||||
floor = np.floor
|
||||
ceil = np.ceil
|
||||
round = lambda x: np.trunc(x + np.copysign(.5, x))
|
||||
round = lambda x: np.trunc(x + np.copysign(.5, x)).astype(x.dtype)
|
||||
nextafter = np.nextafter
|
||||
|
||||
is_finite = np.isfinite
|
||||
@ -47,7 +47,7 @@ cos = np.cos
|
||||
atan2 = np.arctan2
|
||||
|
||||
sqrt = np.sqrt
|
||||
rsqrt = lambda x: 1. / np.sqrt(x)
|
||||
rsqrt = lambda x: np.ones_like(x) / np.sqrt(x)
|
||||
square = np.square
|
||||
reciprocal = np.reciprocal
|
||||
tan = np.tan
|
||||
@ -60,16 +60,17 @@ asinh = np.arcsinh
|
||||
acosh = np.arccosh
|
||||
atanh = np.arctanh
|
||||
|
||||
betainc = scipy.special.betainc
|
||||
lgamma = scipy.special.gammaln
|
||||
digamma = scipy.special.digamma
|
||||
def betainc(a, b, x): return scipy.special.betainc(a, b, x).astype(x.dtype)
|
||||
def lgamma(x): return scipy.special.gammaln(x).astype(x.dtype)
|
||||
def digamma(x): return scipy.special.digamma(x).astype(x.dtype)
|
||||
igamma = scipy.special.gammainc
|
||||
igammac = scipy.special.gammaincc
|
||||
erf = scipy.special.erf
|
||||
erfc = scipy.special.erfc
|
||||
erf_inv = scipy.special.erfinv
|
||||
bessel_i0e = scipy.special.i0e
|
||||
bessel_i1e = scipy.special.i1e
|
||||
def erf(x): return scipy.special.erf(x).astype(x.dtype)
|
||||
def erfc(x): return scipy.special.erfc(x).astype(x.dtype)
|
||||
def erf_inv(x): return scipy.special.erfinv(x).astype(x.dtype)
|
||||
|
||||
def bessel_i0e(x): return scipy.special.i0e(x).astype(x.dtype)
|
||||
def bessel_i1e(x): return scipy.special.i1e(x).astype(x.dtype)
|
||||
|
||||
real = np.real
|
||||
imag = np.imag
|
||||
@ -150,7 +151,7 @@ def bitcast_convert_type(operand, dtype):
|
||||
return np.asarray(operand).view(dtype)
|
||||
|
||||
def clamp(min, operand, max):
|
||||
return np.clip(operand, np.clip(min, None, max), max)
|
||||
return np.clip(operand, np.clip(min, None, max), max).astype(operand.dtype)
|
||||
|
||||
def concatenate(operands, dimension):
|
||||
return np.concatenate(operands, axis=dimension)
|
||||
@ -295,8 +296,6 @@ def sort_key_val(keys, values, dimension=-1):
|
||||
idxs[dimension] = np.argsort(keys, axis=dimension)
|
||||
return keys[tuple(idxs)], values[tuple(idxs)]
|
||||
|
||||
# TODO untake
|
||||
|
||||
### conv util
|
||||
|
||||
def _conv(lhs, rhs, window_strides, pads):
|
||||
|
@ -719,13 +719,14 @@ class JaxTestCase(parameterized.TestCase):
|
||||
def rng(self):
|
||||
return self._rng
|
||||
|
||||
def assertArraysEqual(self, x, y, check_dtypes):
|
||||
def assertArraysEqual(self, x, y, *, check_dtypes=True):
|
||||
"""Assert that x and y arrays are exactly equal."""
|
||||
if check_dtypes:
|
||||
self.assertDtypesMatch(x, y)
|
||||
np.testing.assert_equal(x, y)
|
||||
|
||||
def assertArraysAllClose(self, x, y, check_dtypes, atol=None, rtol=None):
|
||||
def assertArraysAllClose(self, x, y, *, check_dtypes=True, atol=None,
|
||||
rtol=None):
|
||||
"""Assert that x and y are close (up to numerical tolerances)."""
|
||||
self.assertEqual(x.shape, y.shape)
|
||||
atol = max(tolerance(_dtype(x), atol), tolerance(_dtype(y), atol))
|
||||
@ -740,18 +741,20 @@ class JaxTestCase(parameterized.TestCase):
|
||||
if FLAGS.jax_enable_x64:
|
||||
self.assertEqual(_dtype(x), _dtype(y))
|
||||
|
||||
def assertAllClose(self, x, y, check_dtypes, atol=None, rtol=None):
|
||||
def assertAllClose(self, x, y, *, check_dtypes=True, atol=None, rtol=None):
|
||||
"""Assert that x and y, either arrays or nested tuples/lists, are close."""
|
||||
if isinstance(x, dict):
|
||||
self.assertIsInstance(y, dict)
|
||||
self.assertEqual(set(x.keys()), set(y.keys()))
|
||||
for k in x.keys():
|
||||
self.assertAllClose(x[k], y[k], check_dtypes, atol=atol, rtol=rtol)
|
||||
self.assertAllClose(x[k], y[k], check_dtypes=check_dtypes, atol=atol,
|
||||
rtol=rtol)
|
||||
elif is_sequence(x) and not hasattr(x, '__array__'):
|
||||
self.assertTrue(is_sequence(y) and not hasattr(y, '__array__'))
|
||||
self.assertEqual(len(x), len(y))
|
||||
for x_elt, y_elt in zip(x, y):
|
||||
self.assertAllClose(x_elt, y_elt, check_dtypes, atol=atol, rtol=rtol)
|
||||
self.assertAllClose(x_elt, y_elt, check_dtypes=check_dtypes, atol=atol,
|
||||
rtol=rtol)
|
||||
elif hasattr(x, '__array__') or np.isscalar(x):
|
||||
self.assertTrue(hasattr(y, '__array__') or np.isscalar(y))
|
||||
if check_dtypes:
|
||||
@ -772,7 +775,7 @@ class JaxTestCase(parameterized.TestCase):
|
||||
self.assertMultiLineEqual(expected_clean, what_clean,
|
||||
msg="Found\n{}\nExpecting\n{}".format(what, expected))
|
||||
|
||||
def _CompileAndCheck(self, fun, args_maker, check_dtypes,
|
||||
def _CompileAndCheck(self, fun, args_maker, *, check_dtypes=True,
|
||||
rtol=None, atol=None):
|
||||
"""Helper method for running JAX compilation and allclose assertions."""
|
||||
args = args_maker()
|
||||
@ -802,8 +805,10 @@ class JaxTestCase(parameterized.TestCase):
|
||||
python_should_be_executing = False
|
||||
compiled_ans = cfun(*args)
|
||||
|
||||
self.assertAllClose(python_ans, monitored_ans, check_dtypes, atol, rtol)
|
||||
self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol)
|
||||
self.assertAllClose(python_ans, monitored_ans, check_dtypes=check_dtypes,
|
||||
atol=atol, rtol=rtol)
|
||||
self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes,
|
||||
atol=atol, rtol=rtol)
|
||||
|
||||
args = args_maker()
|
||||
|
||||
@ -813,10 +818,11 @@ class JaxTestCase(parameterized.TestCase):
|
||||
python_should_be_executing = False
|
||||
compiled_ans = cfun(*args)
|
||||
|
||||
self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol)
|
||||
self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes,
|
||||
atol=atol, rtol=rtol)
|
||||
|
||||
def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker,
|
||||
check_dtypes=False, tol=None):
|
||||
check_dtypes=True, tol=None):
|
||||
args = args_maker()
|
||||
lax_ans = lax_op(*args)
|
||||
numpy_ans = numpy_reference_op(*args)
|
||||
|
@ -462,7 +462,7 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
def test_grad_and_aux_basic(self):
|
||||
g, aux = grad(lambda x: (x**3, [x**2]), has_aux=True)(3.)
|
||||
self.assertAllClose(g, grad(lambda x: x**3)(3.), check_dtypes=True)
|
||||
self.assertAllClose(g, grad(lambda x: x**3)(3.))
|
||||
self.assertAllClose(aux, [9.], check_dtypes=False)
|
||||
|
||||
def test_grad_and_aux_nested(self):
|
||||
@ -557,7 +557,7 @@ class APITest(jtu.JaxTestCase):
|
||||
res2 = api.jit(inner)(5.)
|
||||
return res1 + res2
|
||||
|
||||
self.assertAllClose((45., 9.), api.jvp(func, (5.,), (1.,)), check_dtypes=True)
|
||||
self.assertAllClose((45., 9.), api.jvp(func, (5.,), (1.,)))
|
||||
|
||||
|
||||
def test_complex_grad_raises_error(self):
|
||||
@ -596,7 +596,7 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
ans = jacrev(f)(zs)
|
||||
expected = grad(f)(zs)
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
def test_complex_input_jacfwd_raises_error(self):
|
||||
self.assertRaises(TypeError, lambda: jacfwd(lambda x: jnp.sin(x))(1 + 2j))
|
||||
@ -980,7 +980,7 @@ class APITest(jtu.JaxTestCase):
|
||||
futures = [executor.submit(partial(f, x)) for x in xs]
|
||||
ys = [f.result() for f in futures]
|
||||
for x, y in zip(xs, ys):
|
||||
self.assertAllClose(x, y, check_dtypes=True)
|
||||
self.assertAllClose(x, y)
|
||||
|
||||
def test_concurrent_jit(self):
|
||||
@jit
|
||||
@ -992,7 +992,7 @@ class APITest(jtu.JaxTestCase):
|
||||
futures = [executor.submit(partial(f, x)) for x in xs]
|
||||
ys = [f.result() for f in futures]
|
||||
for x, y in zip(xs, ys):
|
||||
self.assertAllClose(x * 2 - 3., y, check_dtypes=True)
|
||||
self.assertAllClose(x * 2 - 3., y)
|
||||
|
||||
def test_dtype_warning(self):
|
||||
# cf. issue #1230
|
||||
@ -1054,7 +1054,7 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
out1 = api.vmap(f, (None, 0, 0))(dictionary, x, y)
|
||||
out2 = api.vmap(f, [None, 0, 0])(dictionary, x, y)
|
||||
self.assertAllClose(out1, out2, check_dtypes=True)
|
||||
self.assertAllClose(out1, out2)
|
||||
|
||||
def test_vmap_in_axes_tree_prefix_error(self):
|
||||
# https://github.com/google/jax/issues/795
|
||||
@ -2007,11 +2007,10 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
f.defjvp(f_jvp)
|
||||
|
||||
x = 3.
|
||||
self.assertAllClose(f(x), jnp.sin(x), check_dtypes=True)
|
||||
self.assertAllClose(f(x), jnp.sin(x))
|
||||
self.assertAllClose(api.jvp(f, (x,), (1.,)),
|
||||
(jnp.sin(x), 2 * jnp.cos(x)),
|
||||
check_dtypes=True)
|
||||
self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x), check_dtypes=True)
|
||||
(jnp.sin(x), 2 * jnp.cos(x)))
|
||||
self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x))
|
||||
|
||||
def test_invariance(self):
|
||||
@api.custom_jvp
|
||||
@ -2052,8 +2051,8 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
return f(x), 3 * g
|
||||
f.defjvp(f_jvp)
|
||||
x = 2.
|
||||
self.assertAllClose(f(x), jnp.sin(x), check_dtypes=True)
|
||||
self.assertAllClose(f(-x), jnp.cos(-x), check_dtypes=True)
|
||||
self.assertAllClose(f(x), jnp.sin(x))
|
||||
self.assertAllClose(f(-x), jnp.cos(-x))
|
||||
self.assertAllClose(api.jvp(f, (x,), (1.,)),
|
||||
(jnp.sin(x), 2.),
|
||||
check_dtypes=False)
|
||||
@ -2079,29 +2078,24 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
xx = jnp.arange(6.).reshape(2, 3)
|
||||
|
||||
# vmap of f
|
||||
self.assertAllClose(api.vmap(f)(x), jnp.sin(x), check_dtypes=True)
|
||||
self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx), check_dtypes=True)
|
||||
self.assertAllClose(api.vmap(f)(x), jnp.sin(x))
|
||||
self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx))
|
||||
|
||||
# vmap of jvp of f
|
||||
self.assertAllClose(api.vmap(lambda x: api.jvp(f, (x,), (x,)))(x),
|
||||
(jnp.sin(x), 2 * jnp.cos(x) * x),
|
||||
check_dtypes=True)
|
||||
(jnp.sin(x), 2 * jnp.cos(x) * x))
|
||||
self.assertAllClose(api.vmap(api.vmap(lambda x: api.jvp(f, (x,), (x,))))(xx),
|
||||
(jnp.sin(xx), 2 * jnp.cos(xx) * xx),
|
||||
check_dtypes=True)
|
||||
(jnp.sin(xx), 2 * jnp.cos(xx) * xx))
|
||||
|
||||
# jvp of vmap of f
|
||||
self.assertAllClose(api.jvp(api.vmap(f), (x,), (x,)),
|
||||
(jnp.sin(x), 2 * jnp.cos(x) * x),
|
||||
check_dtypes=True)
|
||||
(jnp.sin(x), 2 * jnp.cos(x) * x))
|
||||
self.assertAllClose(api.jvp(api.vmap(api.vmap(f)), (xx,), (xx,)),
|
||||
(jnp.sin(xx), 2 * jnp.cos(xx) * xx),
|
||||
check_dtypes=True)
|
||||
(jnp.sin(xx), 2 * jnp.cos(xx) * xx))
|
||||
|
||||
# vmap of jvp of vmap of f
|
||||
self.assertAllClose(api.vmap(lambda x: api.jvp(api.vmap(f), (x,), (x,)))(xx),
|
||||
(jnp.sin(xx), 2 * jnp.cos(xx) * xx),
|
||||
check_dtypes=True)
|
||||
(jnp.sin(xx), 2 * jnp.cos(xx) * xx))
|
||||
|
||||
def test_jit(self):
|
||||
@api.custom_jvp
|
||||
@ -2116,8 +2110,8 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
x = 3.
|
||||
|
||||
# jit
|
||||
self.assertAllClose(api.jit(f)(x), jnp.sin(x), check_dtypes=True)
|
||||
self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x), check_dtypes=True)
|
||||
self.assertAllClose(api.jit(f)(x), jnp.sin(x))
|
||||
self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x))
|
||||
|
||||
# jit of jvp
|
||||
self.assertAllClose(api.jit(lambda x: api.jvp(f, (x,), (x,)))(x),
|
||||
@ -2139,7 +2133,7 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
return f(x), {'b': 2 * jnp.cos(x['a']) * g['a']}
|
||||
f.defjvp(f_jvp)
|
||||
x = {'a': 3.}
|
||||
self.assertAllClose(f(x)['b'], jnp.sin(x['a']), check_dtypes=True)
|
||||
self.assertAllClose(f(x)['b'], jnp.sin(x['a']))
|
||||
self.assertAllClose(api.jvp(f, (x,), (x,)),
|
||||
({'b': jnp.sin(x['a'])},
|
||||
{'b': 2 * jnp.cos(x['a']) * x['a']}),
|
||||
@ -2491,11 +2485,10 @@ class CustomVJPTest(jtu.JaxTestCase):
|
||||
f.defvjp(f_fwd, f_rev)
|
||||
|
||||
x = 3.
|
||||
self.assertAllClose(f(x), jnp.sin(x), check_dtypes=True)
|
||||
self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x), check_dtypes=True)
|
||||
self.assertAllClose(f(x), jnp.sin(x))
|
||||
self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x))
|
||||
self.assertAllClose(api.value_and_grad(f)(x),
|
||||
(jnp.sin(x), 2 * jnp.cos(x)),
|
||||
check_dtypes=True)
|
||||
(jnp.sin(x), 2 * jnp.cos(x)))
|
||||
|
||||
def test_invariance(self):
|
||||
@api.custom_vjp
|
||||
@ -2539,8 +2532,8 @@ class CustomVJPTest(jtu.JaxTestCase):
|
||||
return (3 * g,)
|
||||
f.defvjp(f_fwd, f_rev)
|
||||
x = 2.
|
||||
self.assertAllClose(f(x), jnp.sin(x), check_dtypes=True)
|
||||
self.assertAllClose(f(-x), jnp.cos(-x), check_dtypes=True)
|
||||
self.assertAllClose(f(x), jnp.sin(x))
|
||||
self.assertAllClose(f(-x), jnp.cos(-x))
|
||||
self.assertAllClose(api.value_and_grad(f)(x), (jnp.sin(x), 2.),
|
||||
check_dtypes=False)
|
||||
self.assertAllClose(api.value_and_grad(f)(-x), (jnp.cos(-x), 3.),
|
||||
@ -2562,33 +2555,26 @@ class CustomVJPTest(jtu.JaxTestCase):
|
||||
xx = jnp.arange(6.).reshape(2, 3)
|
||||
|
||||
# vmap of f
|
||||
self.assertAllClose(api.vmap(f)(x), jnp.sin(x), check_dtypes=True)
|
||||
self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx), check_dtypes=True)
|
||||
self.assertAllClose(api.vmap(f)(x), jnp.sin(x))
|
||||
self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx))
|
||||
|
||||
# vmap of grad of f
|
||||
self.assertAllClose(api.vmap(api.grad(f))(x), 2 * jnp.cos(x),
|
||||
check_dtypes=True)
|
||||
self.assertAllClose(api.vmap(api.grad(f))(x), 2 * jnp.cos(x))
|
||||
self.assertAllClose(api.vmap(api.value_and_grad(f))(x),
|
||||
(jnp.sin(x), 2 * jnp.cos(x)),
|
||||
check_dtypes=True)
|
||||
self.assertAllClose(api.vmap(api.vmap(api.grad(f)))(xx), 2 * jnp.cos(xx),
|
||||
check_dtypes=True)
|
||||
(jnp.sin(x), 2 * jnp.cos(x)))
|
||||
self.assertAllClose(api.vmap(api.vmap(api.grad(f)))(xx), 2 * jnp.cos(xx))
|
||||
self.assertAllClose(api.vmap(api.vmap(api.value_and_grad(f)))(xx),
|
||||
(jnp.sin(xx), 2 * jnp.cos(xx)),
|
||||
check_dtypes=True)
|
||||
(jnp.sin(xx), 2 * jnp.cos(xx)))
|
||||
|
||||
# grad of vmap of f
|
||||
self.assertAllClose(api.grad(lambda x: api.vmap(f)(x).sum())(x),
|
||||
2 * jnp.cos(x),
|
||||
check_dtypes=True)
|
||||
2 * jnp.cos(x))
|
||||
self.assertAllClose(api.grad(lambda x: api.vmap(api.vmap(f))(x).sum())(xx),
|
||||
2 * jnp.cos(xx),
|
||||
check_dtypes=True)
|
||||
2 * jnp.cos(xx))
|
||||
|
||||
# vmap of grad of vmap of f
|
||||
self.assertAllClose(api.vmap(api.grad(lambda x: api.vmap(f)(x).sum()))(xx),
|
||||
2 * jnp.cos(xx),
|
||||
check_dtypes=True)
|
||||
2 * jnp.cos(xx))
|
||||
|
||||
def test_jit(self):
|
||||
@api.custom_vjp
|
||||
@ -2603,8 +2589,8 @@ class CustomVJPTest(jtu.JaxTestCase):
|
||||
x = 3.
|
||||
|
||||
# jit
|
||||
self.assertAllClose(api.jit(f)(x), jnp.sin(x), check_dtypes=True)
|
||||
self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x), check_dtypes=True)
|
||||
self.assertAllClose(api.jit(f)(x), jnp.sin(x))
|
||||
self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x))
|
||||
|
||||
# jit of grad
|
||||
self.assertAllClose(api.jit(api.grad(f))(x), 2 * jnp.cos(x),
|
||||
@ -2625,10 +2611,9 @@ class CustomVJPTest(jtu.JaxTestCase):
|
||||
return ({'a': 2 * cos_x * g['b']},)
|
||||
f.defvjp(f_fwd, f_bwd)
|
||||
x = {'a': 3.}
|
||||
self.assertAllClose(f(x)['b'], jnp.sin(x['a']), check_dtypes=True)
|
||||
self.assertAllClose(f(x)['b'], jnp.sin(x['a']))
|
||||
self.assertAllClose(api.grad(lambda x: f(x)['b'])(x),
|
||||
{'a': 2 * jnp.cos(x['a'])},
|
||||
check_dtypes=True)
|
||||
{'a': 2 * jnp.cos(x['a'])})
|
||||
|
||||
def test_jvp_error(self):
|
||||
@api.custom_vjp
|
||||
@ -2680,7 +2665,7 @@ class CustomVJPTest(jtu.JaxTestCase):
|
||||
|
||||
ans = api.grad(api.grad(foo))(3.)
|
||||
expected = -2. * jnp.sin(3.)
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
def test_initial_style_vmap(self):
|
||||
@api.custom_vjp
|
||||
@ -2851,7 +2836,7 @@ class DeprecatedCustomTransformsTest(jtu.JaxTestCase):
|
||||
ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (12.,)))
|
||||
val_ans, grad_ans = api.value_and_grad(foo)(3.)
|
||||
self.assertAllClose(val_ans, 9., check_dtypes=False)
|
||||
self.assertAllClose(grad_ans, 12., check_dtypes=True)
|
||||
self.assertAllClose(grad_ans, 12.)
|
||||
|
||||
def test_defvjp_all_higher_order_revmode(self):
|
||||
foo_p = Primitive('foo')
|
||||
|
@ -69,7 +69,7 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
x = jnp.array(np)
|
||||
dlpack = jax.dlpack.to_dlpack(x)
|
||||
y = jax.dlpack.from_dlpack(dlpack)
|
||||
self.assertAllClose(np.astype(x.dtype), y, check_dtypes=True)
|
||||
self.assertAllClose(np.astype(x.dtype), y)
|
||||
|
||||
self.assertRaisesRegex(RuntimeError,
|
||||
"DLPack tensor may be consumed at most once",
|
||||
@ -89,7 +89,7 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
x = x.cuda() if jtu.device_under_test() == "gpu" else x
|
||||
dlpack = torch.utils.dlpack.to_dlpack(x)
|
||||
y = jax.dlpack.from_dlpack(dlpack)
|
||||
self.assertAllClose(np, y, check_dtypes=True)
|
||||
self.assertAllClose(np, y)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}".format(
|
||||
@ -104,7 +104,7 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
x = jnp.array(np)
|
||||
dlpack = jax.dlpack.to_dlpack(x)
|
||||
y = torch.utils.dlpack.from_dlpack(dlpack)
|
||||
self.assertAllClose(np, y.numpy(), check_dtypes=True)
|
||||
self.assertAllClose(np, y.numpy())
|
||||
|
||||
|
||||
class CudaArrayInterfaceTest(jtu.JaxTestCase):
|
||||
@ -128,7 +128,7 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):
|
||||
z = cupy.asarray(y)
|
||||
self.assertEqual(y.__cuda_array_interface__["data"][0],
|
||||
z.__cuda_array_interface__["data"][0])
|
||||
self.assertAllClose(x, cupy.asnumpy(z), check_dtypes=True)
|
||||
self.assertAllClose(x, cupy.asnumpy(z))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -195,7 +195,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
|
||||
ans = vmap(lambda x: x > 1.0)(x)
|
||||
expected_ans = x > 1.0
|
||||
self.assertAllClose(ans, expected_ans, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected_ans)
|
||||
|
||||
def testNpMaximumPerExampleGrad(self):
|
||||
R = np.random.RandomState(0).randn
|
||||
@ -226,35 +226,35 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
|
||||
ans = vmap(fun)(x, y)
|
||||
expected = lax.dot_general(x, y, [((3,), (2,)), ((0, 1), (0, 1))])
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
x = R(3, 4, 10, 5)
|
||||
y = R(3, 10, 5, 6)
|
||||
fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
|
||||
ans = vmap(fun, in_axes=(2, 1))(x, y)
|
||||
expected = np.stack([fun(x[..., i, :], y[:, i, ...]) for i in range(10)])
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
x = R(3, 4, 5, 10)
|
||||
y = R(3, 5, 6)
|
||||
fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
|
||||
ans = vmap(fun, in_axes=(3, None))(x, y)
|
||||
expected = np.stack([fun(x[..., i], y) for i in range(10)])
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
x = R(3, 4, 5)
|
||||
y = R(3, 5, 10, 6)
|
||||
fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
|
||||
ans = vmap(fun, in_axes=(None, 2))(x, y)
|
||||
expected = np.stack([fun(x, y[..., i, :]) for i in range(10)])
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
x = R(4)
|
||||
y = R(4, 10)
|
||||
fun = lambda x, y: lax.dot_general(x, y, [((0,), (0,)), ((), ())])
|
||||
ans = vmap(fun, in_axes=(None, 1))(x, y)
|
||||
expected = np.stack([fun(x, y[..., i]) for i in range(10)])
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
def testDot(self):
|
||||
# these tests are based on @shoyer's notebook studying gufuncs
|
||||
@ -348,7 +348,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
|
||||
ans = vmap(jnp.any)(jnp.array([[True, False], [False, False]]))
|
||||
expected = jnp.array([True, False])
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def testHessian(self):
|
||||
@ -417,46 +417,44 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
v = np.arange(12)[::-1].reshape(3, 4)
|
||||
|
||||
sv = vmap(partial(lax.sort, dimension=0), (0,))(v)
|
||||
self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)
|
||||
self.assertAllClose(sv, v[:, ::-1])
|
||||
|
||||
sv = vmap(partial(lax.sort, dimension=-1), (0,))(v)
|
||||
self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)
|
||||
self.assertAllClose(sv, v[:, ::-1])
|
||||
|
||||
sv = vmap(partial(lax.sort, dimension=0), (1,))(v)
|
||||
self.assertAllClose(sv, v[::-1, :].T, check_dtypes=True)
|
||||
self.assertAllClose(sv, v[::-1, :].T)
|
||||
|
||||
sv = vmap(partial(lax.sort, dimension=0), (1,), 1)(v)
|
||||
self.assertAllClose(sv, v[::-1, :], check_dtypes=True)
|
||||
self.assertAllClose(sv, v[::-1, :])
|
||||
|
||||
def testSortKeyVal(self):
|
||||
k = np.arange(12)[::-1].reshape(3, 4)
|
||||
v = np.random.RandomState(0).permutation(12).reshape(3, 4)
|
||||
|
||||
sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 0))(k, v)
|
||||
self.assertAllClose(sk, k[:, ::-1], check_dtypes=True)
|
||||
self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)
|
||||
self.assertAllClose(sk, k[:, ::-1])
|
||||
self.assertAllClose(sv, v[:, ::-1])
|
||||
|
||||
sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 1), 1)(k, v)
|
||||
self.assertAllClose(sk, k[::-1, :], check_dtypes=True)
|
||||
self.assertAllClose(sv, v[::-1, :], check_dtypes=True)
|
||||
self.assertAllClose(sk, k[::-1, :])
|
||||
self.assertAllClose(sv, v[::-1, :])
|
||||
|
||||
sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 1))(k, v.T)
|
||||
self.assertAllClose(sk, k[:, ::-1], check_dtypes=True)
|
||||
self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)
|
||||
self.assertAllClose(sk, k[:, ::-1])
|
||||
self.assertAllClose(sv, v[:, ::-1])
|
||||
|
||||
sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 0))(k.T, v)
|
||||
self.assertAllClose(sk, k[:, ::-1], check_dtypes=True)
|
||||
self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)
|
||||
self.assertAllClose(sk, k[:, ::-1])
|
||||
self.assertAllClose(sv, v[:, ::-1])
|
||||
|
||||
sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (None, 0))(k[0], v)
|
||||
self.assertAllClose(sk, np.broadcast_to(k[0, ::-1], (3, 4)),
|
||||
check_dtypes=True)
|
||||
self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)
|
||||
self.assertAllClose(sk, np.broadcast_to(k[0, ::-1], (3, 4)))
|
||||
self.assertAllClose(sv, v[:, ::-1])
|
||||
|
||||
sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, None))(k.T, v[0])
|
||||
self.assertAllClose(sk, k[:, ::-1], check_dtypes=True)
|
||||
self.assertAllClose(sv, np.broadcast_to(v[0, ::-1], (3, 4)),
|
||||
check_dtypes=True)
|
||||
self.assertAllClose(sk, k[:, ::-1])
|
||||
self.assertAllClose(sv, np.broadcast_to(v[0, ::-1], (3, 4)))
|
||||
|
||||
def testConvGeneralDilated(self):
|
||||
W = jnp.array(np.random.randn(3, 3, 1, 5), dtype=np.float32)
|
||||
@ -474,7 +472,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
per_example = vmap(partial(f, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
|
||||
per_example = jnp.reshape(per_example, (10, 5, 5, 5))
|
||||
per_example_direct = f(W, X)
|
||||
self.assertAllClose(per_example, per_example_direct, check_dtypes=True)
|
||||
self.assertAllClose(per_example, per_example_direct)
|
||||
|
||||
# Test gradients.
|
||||
per_example = vmap(partial(grad_loss, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
|
||||
@ -484,7 +482,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
per_example_direct += [
|
||||
jnp.reshape(g, (1,) + g.shape)]
|
||||
per_example_direct = jnp.concatenate(per_example_direct, axis=0)
|
||||
self.assertAllClose(per_example, per_example_direct, check_dtypes=True,
|
||||
self.assertAllClose(per_example, per_example_direct,
|
||||
rtol=2e-2)
|
||||
|
||||
def testConvGeneralDilatedBatchNotMajor(self):
|
||||
@ -503,7 +501,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
(5, 5, 21, 4))
|
||||
per_example_direct = f(W, jnp.reshape(jnp.transpose(x, (1, 0, 2, 3, 4)),
|
||||
(5, 21, 5, 1)))
|
||||
self.assertAllClose(per_example, per_example_direct, check_dtypes=True)
|
||||
self.assertAllClose(per_example, per_example_direct)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": "_op={}".format(name), "op": op, "unit": unit}
|
||||
@ -526,7 +524,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
per_example = vmap(partial(f, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
|
||||
per_example = jnp.reshape(per_example, (10, 5, 5, 5))
|
||||
per_example_direct = f(W, X)
|
||||
self.assertAllClose(per_example, per_example_direct, check_dtypes=True)
|
||||
self.assertAllClose(per_example, per_example_direct)
|
||||
|
||||
# Test gradients.
|
||||
per_example = vmap(partial(grad_loss, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
|
||||
@ -536,7 +534,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
per_example_direct += [
|
||||
jnp.reshape(g, (1,) + g.shape)]
|
||||
per_example_direct = jnp.concatenate(per_example_direct, axis=0)
|
||||
self.assertAllClose(per_example, per_example_direct, check_dtypes=True,
|
||||
self.assertAllClose(per_example, per_example_direct,
|
||||
rtol=5e-2)
|
||||
|
||||
def testSumPool(self):
|
||||
@ -557,7 +555,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
per_example = vmap(partial(f, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
|
||||
per_example = jnp.reshape(per_example, (10, 5, 5, 5))
|
||||
per_example_direct = f(W, X)
|
||||
self.assertAllClose(per_example, per_example_direct, check_dtypes=True)
|
||||
self.assertAllClose(per_example, per_example_direct)
|
||||
|
||||
# Test gradients.
|
||||
per_example = vmap(partial(grad_loss, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
|
||||
@ -567,14 +565,13 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
per_example_direct += [
|
||||
jnp.reshape(g, (1,) + g.shape)]
|
||||
per_example_direct = jnp.concatenate(per_example_direct, axis=0)
|
||||
self.assertAllClose(per_example, per_example_direct, check_dtypes=True,
|
||||
self.assertAllClose(per_example, per_example_direct,
|
||||
rtol=3e-2)
|
||||
|
||||
def testCumProd(self):
|
||||
x = jnp.arange(9).reshape(3, 3) + 1
|
||||
y = vmap(lambda x: jnp.cumprod(x, axis=-1))(x)
|
||||
self.assertAllClose(np.cumprod(x, axis=1, dtype=jnp.int_), y,
|
||||
check_dtypes=True)
|
||||
self.assertAllClose(np.cumprod(x, axis=1, dtype=jnp.int_), y)
|
||||
|
||||
def testSelect(self):
|
||||
pred = np.array([True, False])
|
||||
@ -582,7 +579,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
on_false = np.array([2, 3])
|
||||
ans = vmap(lax.select)(pred, on_true, on_false)
|
||||
expected = np.array([0, 3])
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
pred = np.array([False, True])
|
||||
on_true = np.array([0, 1])
|
||||
@ -590,28 +587,28 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
ans = vmap(lax.select, (0, None, None))(pred, on_true, on_false)
|
||||
expected = np.array([[2, 3],
|
||||
[0, 1]])
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
pred = True
|
||||
on_true = np.array([0, 1], np.float32)
|
||||
on_false = np.array(3, np.float32)
|
||||
ans = vmap(lax.select, (None, 0, None))(pred, on_true, on_false)
|
||||
expected = np.array([0, 1], np.float32)
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
pred = np.array([False, True])
|
||||
on_true = np.array([0, 1], np.float32)
|
||||
on_false = np.array(3, np.float32)
|
||||
ans = vmap(lax.select, (0, 0, None))(pred, on_true, on_false)
|
||||
expected = np.array([3, 1], np.float32)
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
pred = np.array([False, True])
|
||||
on_true = np.array([2], np.float32)
|
||||
on_false = np.array([[3, 4]], np.float32)
|
||||
ans = vmap(lax.select, (0, None, 1), 1)(pred, on_true, on_false)
|
||||
expected = np.array([[3, 2]], np.float32)
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
def testLaxLinalgCholesky(self):
|
||||
a = np.random.RandomState(0).randn(10, 5, 5).astype(np.float32)
|
||||
@ -637,17 +634,17 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
ans = vmap(lax_linalg.triangular_solve, in_axes=(1, 2))(a, b)
|
||||
expected = np.stack(
|
||||
[lax_linalg.triangular_solve(a[:, i], b[..., i]) for i in range(10)])
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
ans = vmap(lax_linalg.triangular_solve, in_axes=(None, 2))(a[:, 0], b)
|
||||
expected = np.stack(
|
||||
[lax_linalg.triangular_solve(a[:, 0], b[..., i]) for i in range(10)])
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
ans = vmap(lax_linalg.triangular_solve, in_axes=(1, None))(a, b[..., 0])
|
||||
expected = np.stack(
|
||||
[lax_linalg.triangular_solve(a[:, i], b[..., 0]) for i in range(10)])
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
|
||||
@ -987,7 +984,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
g = jax.jit(jax.pmap(f))
|
||||
ans = g(index1=np.asarray([1]), index2=np.asarray([2]))
|
||||
expected = g(np.asarray([1]), np.asarray([2]))
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -69,11 +69,11 @@ class CallbackTest(jtu.JaxTestCase):
|
||||
return x * 2
|
||||
|
||||
x = jnp.array([2.0, 4.0])
|
||||
self.assertAllClose(f(x), jnp.array([4.0, 8.0]), True)
|
||||
self.assertAllClose(f(x), jnp.array([4.0, 8.0]))
|
||||
|
||||
self.assertAllClose(
|
||||
rewrite(f, {lax.mul_p: lambda x, y: x + y})(x),
|
||||
jnp.array([4.0, 6.0]), True)
|
||||
jnp.array([4.0, 6.0]))
|
||||
|
||||
def testRewriteJIT(self):
|
||||
def f(x):
|
||||
@ -83,11 +83,11 @@ class CallbackTest(jtu.JaxTestCase):
|
||||
return g(x)
|
||||
|
||||
x = jnp.array([2.0, 4.0])
|
||||
self.assertAllClose(f(x), jnp.array([4.0, 8.0]), True)
|
||||
self.assertAllClose(f(x), jnp.array([4.0, 8.0]))
|
||||
|
||||
self.assertAllClose(
|
||||
rewrite(f, {lax.mul_p: lambda x, y: x + y})(x),
|
||||
jnp.array([4.0, 6.0]), True)
|
||||
jnp.array([4.0, 6.0]))
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
|
@ -109,7 +109,7 @@ class FftTest(jtu.JaxTestCase):
|
||||
# Numpy promotes to complex128 aggressively.
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fn, args_maker)
|
||||
# Test gradient for differentiable types.
|
||||
if dtype in (float_dtypes if real and not inverse else inexact_dtypes):
|
||||
# TODO(skye): can we be more precise?
|
||||
@ -169,7 +169,7 @@ class FftTest(jtu.JaxTestCase):
|
||||
# Numpy promotes to complex128 aggressively.
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_inverse={}_real={}".format(inverse, real),
|
||||
@ -228,7 +228,7 @@ class FftTest(jtu.JaxTestCase):
|
||||
# Numpy promotes to complex128 aggressively.
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_inverse={}_real={}".format(inverse, real),
|
||||
@ -279,7 +279,7 @@ class FftTest(jtu.JaxTestCase):
|
||||
# Numpy promotes to complex128 aggressively.
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fn, args_maker)
|
||||
# Test gradient for differentiable types.
|
||||
if dtype in inexact_dtypes:
|
||||
tol = 0.15 # TODO(skye): can we be more precise?
|
||||
@ -323,7 +323,7 @@ class FftTest(jtu.JaxTestCase):
|
||||
# Numpy promotes to complex128 aggressively.
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fn, args_maker)
|
||||
# Test gradient for differentiable types.
|
||||
if dtype in inexact_dtypes:
|
||||
tol = 0.15 # TODO(skye): can we be more precise?
|
||||
@ -362,7 +362,7 @@ class FftTest(jtu.JaxTestCase):
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda arg: jnp.fft.fftshift(arg, axes=axes)
|
||||
np_fn = lambda arg: np.fft.fftshift(arg, axes=axes)
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "dtype={}_axes={}".format(
|
||||
@ -377,7 +377,7 @@ class FftTest(jtu.JaxTestCase):
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda arg: jnp.fft.ifftshift(arg, axes=axes)
|
||||
np_fn = lambda arg: np.fft.ifftshift(arg, axes=axes)
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
|
@ -148,7 +148,7 @@ class HostCallbackTest(jtu.JaxTestCase):
|
||||
self.assertEqual("", testing_stream.output)
|
||||
|
||||
with hcb.outfeed_receiver():
|
||||
self.assertAllClose((5. * 2.) ** 2, fun1(5.), check_dtypes=True)
|
||||
self.assertAllClose((5. * 2.) ** 2, fun1(5.))
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: a * 2
|
||||
10.00
|
||||
@ -232,7 +232,7 @@ what: x3
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
res = jit_fun1(5.)
|
||||
|
||||
self.assertAllClose(6. * 5., res, check_dtypes=True)
|
||||
self.assertAllClose(6. * 5., res)
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: here
|
||||
10.00""", testing_stream.output)
|
||||
@ -258,7 +258,7 @@ what: here
|
||||
self.assertEqual("", testing_stream.output)
|
||||
|
||||
with hcb.outfeed_receiver():
|
||||
self.assertAllClose(5, api.jit(func)(5), check_dtypes=True)
|
||||
self.assertAllClose(5, api.jit(func)(5))
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
42""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
@ -512,7 +512,7 @@ where: 3
|
||||
if with_jit:
|
||||
func = api.jit(func)
|
||||
res = func(1)
|
||||
self.assertAllClose(jnp.array([1, 2, 3]), res, check_dtypes=True)
|
||||
self.assertAllClose(jnp.array([1, 2, 3]), res)
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
where: 1
|
||||
1
|
||||
@ -564,7 +564,7 @@ where: 10
|
||||
testcase_name=f"shape_{shape}_dtype_{dtype}_nr_args={nr_args}"))
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
res = jit_fun1(args)
|
||||
# self.assertAllClose(args, res, check_dtypes=True)
|
||||
# self.assertAllClose(args, res)
|
||||
|
||||
def test_jit_large(self):
|
||||
arg = jnp.arange(10000, dtype=jnp.int32).reshape((10, 10, 5, -1))
|
||||
|
@ -44,7 +44,7 @@ class InfeedTest(jax.test_util.JaxTestCase):
|
||||
device = jax.local_devices()[0]
|
||||
device.transfer_to_infeed((y,))
|
||||
device.transfer_to_infeed((z,))
|
||||
self.assertAllClose(f(x), x + y + z, check_dtypes=True)
|
||||
self.assertAllClose(f(x), x + y + z)
|
||||
|
||||
def testInfeedThenOutfeed(self):
|
||||
@jax.jit
|
||||
@ -64,7 +64,7 @@ class InfeedTest(jax.test_util.JaxTestCase):
|
||||
out, = device.transfer_from_outfeed(
|
||||
xla_client.shape_from_pyval((y,)).with_major_to_minor_layout_if_absent())
|
||||
execution.join()
|
||||
self.assertAllClose(out, y + onp.float32(1), check_dtypes=True)
|
||||
self.assertAllClose(out, y + onp.float32(1))
|
||||
|
||||
def testInfeedThenOutfeedInALoop(self):
|
||||
def doubler(_, token):
|
||||
@ -87,7 +87,7 @@ class InfeedTest(jax.test_util.JaxTestCase):
|
||||
device.transfer_to_infeed((x,))
|
||||
y, = device.transfer_from_outfeed(xla_client.shape_from_pyval((x,))
|
||||
.with_major_to_minor_layout_if_absent())
|
||||
self.assertAllClose(y, x * onp.float32(2), check_dtypes=True)
|
||||
self.assertAllClose(y, x * onp.float32(2))
|
||||
execution.join()
|
||||
|
||||
|
||||
|
@ -159,11 +159,9 @@ class JetTest(jtu.JaxTestCase):
|
||||
|
||||
atol = 1e-4
|
||||
rtol = 1e-4
|
||||
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol,
|
||||
check_dtypes=True)
|
||||
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol)
|
||||
|
||||
self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol,
|
||||
check_dtypes=True)
|
||||
self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
|
@ -404,7 +404,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
self.assertEqual(count(3), 3)
|
||||
self.assertEqual(count(4), 6)
|
||||
for args_maker in [lambda: [2], lambda: [3], lambda: [4]]:
|
||||
self._CompileAndCheck(count, args_maker, True)
|
||||
self._CompileAndCheck(count, args_maker)
|
||||
|
||||
def testForiLoopClosure(self):
|
||||
def count(num):
|
||||
@ -1356,13 +1356,13 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
xs = jnp.arange(10)
|
||||
expected = xs ** 2
|
||||
actual = lax.map(f, xs)
|
||||
self.assertAllClose(actual, expected, check_dtypes=True)
|
||||
self.assertAllClose(actual, expected)
|
||||
|
||||
def testMapEmpty(self):
|
||||
# https://github.com/google/jax/issues/2412
|
||||
ans = lax.map(lambda x: x * x, jnp.array([]))
|
||||
expected = jnp.array([])
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
def testCaching(self):
|
||||
def cond(x):
|
||||
@ -1598,7 +1598,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
actual = api.jit(linear_solve)(a, b)
|
||||
expected = jnp.linalg.solve(a, b)
|
||||
self.assertAllClose(expected, actual, check_dtypes=True)
|
||||
self.assertAllClose(expected, actual)
|
||||
|
||||
def test_custom_root_with_custom_linear_solve(self):
|
||||
|
||||
@ -1620,11 +1620,11 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
actual = linear_solve(jnp.dot(a, a.T), b)
|
||||
expected = jnp.linalg.solve(jnp.dot(a, a.T), b)
|
||||
self.assertAllClose(expected, actual, check_dtypes=True)
|
||||
self.assertAllClose(expected, actual)
|
||||
|
||||
actual = api.jit(linear_solve)(jnp.dot(a, a.T), b)
|
||||
expected = jnp.linalg.solve(jnp.dot(a, a.T), b)
|
||||
self.assertAllClose(expected, actual, check_dtypes=True)
|
||||
self.assertAllClose(expected, actual)
|
||||
|
||||
jtu.check_grads(lambda x, y: linear_solve(jnp.dot(x, x.T), y),
|
||||
(a, b), order=2, rtol={jnp.float32: 1e-2})
|
||||
@ -1670,12 +1670,12 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
expected = jnp.linalg.solve(a, b)
|
||||
actual = api.jit(linear_solve)(a, b)
|
||||
self.assertAllClose(expected, actual, check_dtypes=True)
|
||||
self.assertAllClose(expected, actual)
|
||||
|
||||
c = rng.randn(3, 2)
|
||||
expected = jnp.linalg.solve(a, c)
|
||||
actual = api.vmap(linear_solve, (None, 1), 1)(a, c)
|
||||
self.assertAllClose(expected, actual, check_dtypes=True)
|
||||
self.assertAllClose(expected, actual)
|
||||
|
||||
@jtu.skip_on_flag("jax_skip_slow_tests", True)
|
||||
def test_custom_linear_solve_zeros(self):
|
||||
@ -1723,7 +1723,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
b = rng.randn(2)
|
||||
expected = jnp.linalg.solve(jnp.exp(a), jnp.cos(b))
|
||||
actual = build_and_solve(a, b)
|
||||
self.assertAllClose(expected, actual, atol=1e-5, check_dtypes=True)
|
||||
self.assertAllClose(expected, actual, atol=1e-5)
|
||||
jtu.check_grads(build_and_solve, (a, b), atol=1e-5, order=2,
|
||||
rtol={jnp.float32: 6e-2, jnp.float64: 2e-3})
|
||||
|
||||
@ -1749,10 +1749,10 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
expected = jnp.linalg.solve(np.asarray(posify(a)), b)
|
||||
actual = positive_definite_solve(posify(a), b)
|
||||
self.assertAllClose(expected, actual, check_dtypes=True)
|
||||
self.assertAllClose(expected, actual)
|
||||
|
||||
actual = api.jit(positive_definite_solve)(posify(a), b)
|
||||
self.assertAllClose(expected, actual, check_dtypes=True)
|
||||
self.assertAllClose(expected, actual)
|
||||
|
||||
# numerical gradients are only well defined if ``a`` is guaranteed to be
|
||||
# positive definite.
|
||||
@ -1798,7 +1798,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
expected = jnp.linalg.solve(a, b)
|
||||
actual = linear_solve(a, b)
|
||||
self.assertAllClose(expected, actual, check_dtypes=True)
|
||||
self.assertAllClose(expected, actual)
|
||||
|
||||
jtu.check_grads(linear_solve, (a, b), order=2, rtol=2e-3)
|
||||
|
||||
|
@ -33,7 +33,7 @@ class EinsumTest(jtu.JaxTestCase):
|
||||
def _check(self, s, *ops):
|
||||
a = np.einsum(s, *ops)
|
||||
b = jnp.einsum(s, *ops, precision=lax.Precision.HIGHEST)
|
||||
self.assertAllClose(a, b, atol=1e-4, rtol=1e-4, check_dtypes=True)
|
||||
self.assertAllClose(a, b, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_three_operands_1(self):
|
||||
r = self.rng()
|
||||
|
@ -402,7 +402,7 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
rng = rng_factory(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
fun = lambda x: x[indexer]
|
||||
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters({
|
||||
"testcase_name":
|
||||
@ -503,7 +503,7 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
return x[indexer]
|
||||
|
||||
args_maker = lambda: [rng(shape, dtype), unpacked_indexer]
|
||||
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": "{}_inshape={}_indexer={}"
|
||||
@ -553,7 +553,7 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
rng = rng_factory(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype), indexer]
|
||||
fun = lambda x, idx: jnp.asarray(x)[idx]
|
||||
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": "{}_inshape={}_indexer={}"
|
||||
@ -635,7 +635,7 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
idx = type(indexer)(util.subvals(indexer_with_dummies, substitutes))
|
||||
return jnp.asarray(x)[idx]
|
||||
|
||||
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
|
||||
def testAdvancedIndexingManually(self):
|
||||
x = onp.random.RandomState(0).randn(3, 4, 5)
|
||||
@ -647,7 +647,7 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
a1 = op(x, index_array)
|
||||
a2 = cop(x, index_array)
|
||||
|
||||
self.assertAllClose(a1, a2, check_dtypes=True)
|
||||
self.assertAllClose(a1, a2)
|
||||
|
||||
op = lambda x, index_array: x[..., index_array, :, index_array, None]
|
||||
cop = api.jit(op)
|
||||
@ -655,7 +655,7 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
a1 = op(x, index_array)
|
||||
a2 = cop(x, index_array)
|
||||
|
||||
self.assertAllClose(a1, a2, check_dtypes=True)
|
||||
self.assertAllClose(a1, a2)
|
||||
|
||||
op = lambda x, index_array: x[index_array, ..., index_array[:, None], None]
|
||||
cop = api.jit(op)
|
||||
@ -663,7 +663,7 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
a1 = op(x, index_array)
|
||||
a2 = cop(x, index_array)
|
||||
|
||||
self.assertAllClose(a1, a2, check_dtypes=True)
|
||||
self.assertAllClose(a1, a2)
|
||||
|
||||
def testUnpacking(self):
|
||||
|
||||
@ -676,7 +676,7 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
a1 = foo(onp.arange(3))
|
||||
a2 = cfoo(onp.arange(3))
|
||||
|
||||
self.assertAllClose(a1, a2, check_dtypes=True)
|
||||
self.assertAllClose(a1, a2)
|
||||
|
||||
def testBooleanIndexingArray1D(self):
|
||||
idx = onp.array([True, True, False])
|
||||
@ -739,8 +739,8 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
primals, tangents = api.jvp(api.grad(f), (x, i), (x, onp.zeros_like(i)))
|
||||
expected = onp.broadcast_to(
|
||||
onp.array([0, 3, 0], dtype=onp.float32)[:, None], (3, 4))
|
||||
self.assertAllClose(expected, primals, check_dtypes=True)
|
||||
self.assertAllClose(onp.zeros_like(x), tangents, check_dtypes=True)
|
||||
self.assertAllClose(expected, primals)
|
||||
self.assertAllClose(onp.zeros_like(x), tangents)
|
||||
|
||||
def testTrivialGatherIsntGenerated(self):
|
||||
# https://github.com/google/jax/issues/1621
|
||||
@ -791,7 +791,7 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
|
||||
def testIndexOutOfBounds(self): # https://github.com/google/jax/issues/2245
|
||||
array = jnp.ones(5)
|
||||
self.assertAllClose(array, array[:10], check_dtypes=True)
|
||||
self.assertAllClose(array, array[:10])
|
||||
|
||||
|
||||
def _broadcastable_shapes(shape):
|
||||
@ -877,8 +877,8 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y)
|
||||
else:
|
||||
jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
|
||||
self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker)
|
||||
self._CompileAndCheck(jax_fn, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list({
|
||||
"testcase_name": "{}_inshape={}_indexer={}_update={}_sugared={}_op={}".format(
|
||||
@ -904,8 +904,8 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y)
|
||||
else:
|
||||
jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
|
||||
self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker)
|
||||
self._CompileAndCheck(jax_fn, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list({
|
||||
"testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
|
||||
@ -931,8 +931,8 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y)
|
||||
else:
|
||||
jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
|
||||
self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker)
|
||||
self._CompileAndCheck(jax_fn, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list({
|
||||
"testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
|
||||
|
@ -91,7 +91,8 @@ OpRecord = collections.namedtuple(
|
||||
"test_name", "check_dtypes", "tolerance", "inexact"])
|
||||
|
||||
def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
|
||||
test_name=None, check_dtypes=True, tolerance=None, inexact=False):
|
||||
test_name=None, check_dtypes=True,
|
||||
tolerance=None, inexact=False):
|
||||
test_name = test_name or name
|
||||
return OpRecord(name, nargs, dtypes, shapes, rng_factory, diff_modes,
|
||||
test_name, check_dtypes, tolerance, inexact)
|
||||
@ -497,7 +498,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
() in shapes)
|
||||
empty_shape = any(isinstance(s, tuple) and 0 in s for s in shapes)
|
||||
self._CompileAndCheck(
|
||||
fun, args_maker, check_dtypes=True, #not scalar_arg and not empty_shape,
|
||||
fun, args_maker, #not scalar_arg and not empty_shape,
|
||||
atol=tol, rtol=tol)
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
@ -525,7 +526,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
() in shapes)
|
||||
empty_shape = any(isinstance(s, tuple) and 0 in s for s in shapes)
|
||||
self._CompileAndCheck(
|
||||
fun, args_maker, check_dtypes=True, # not scalar_arg and not empty_shape,
|
||||
fun, args_maker, # not scalar_arg and not empty_shape,
|
||||
atol=tol, rtol=tol)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -577,7 +578,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = self._GetArgsMaker(rng, shapes, dtypes)
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker,
|
||||
check_dtypes=jtu.PYTHON_SCALAR_SHAPE not in shapes)
|
||||
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list(
|
||||
@ -616,7 +617,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
|
||||
check_dtypes=jnp.bfloat16 not in (dtype, out_dtype),
|
||||
tol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, atol=tol,
|
||||
self._CompileAndCheck(jnp_fun, args_maker, atol=tol,
|
||||
rtol=tol)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -646,8 +647,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims)
|
||||
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_axis={}".format(
|
||||
@ -661,7 +662,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
jnp_fun = lambda x: jnp.count_nonzero(x, axis)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}".format(
|
||||
@ -733,13 +734,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
try:
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
except ValueError as e:
|
||||
if str(e) == "All-NaN slice encountered":
|
||||
self.skipTest("JAX doesn't support checking for all-NaN slices")
|
||||
else:
|
||||
raise
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": rec.test_name.capitalize(), "name": rec.name,
|
||||
@ -755,8 +756,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np_fun = partial(np_op, axis=0)
|
||||
jnp_fun = partial(jnp_op, axis=0)
|
||||
args_maker = lambda: [np.zeros((2, 0))]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_{}_{}".format(
|
||||
@ -793,9 +794,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
tol_spec = {dtypes.bfloat16: 3e-1, np.float16: 0.15}
|
||||
tol = max(jtu.tolerance(lhs_dtype, tol_spec),
|
||||
jtu.tolerance(rhs_dtype, tol_spec))
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True,
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
|
||||
tol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, atol=tol,
|
||||
self._CompileAndCheck(jnp_fun, args_maker, atol=tol,
|
||||
rtol=tol)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -830,9 +831,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
x = x.astype(np.float32) if lhs_dtype == jnp.bfloat16 else x
|
||||
y = y.astype(np.float32) if rhs_dtype == jnp.bfloat16 else y
|
||||
return np.dot(x, y).astype(jnp.promote_types(lhs_dtype, rhs_dtype))
|
||||
self._CheckAgainstNumpy(np_dot, jnp.dot, args_maker, check_dtypes=True,
|
||||
self._CheckAgainstNumpy(np_dot, jnp.dot, args_maker,
|
||||
tol=tol)
|
||||
self._CompileAndCheck(jnp.dot, args_maker, check_dtypes=True, atol=tol,
|
||||
self._CompileAndCheck(jnp.dot, args_maker, atol=tol,
|
||||
rtol=tol)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -866,10 +867,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np.complex128: 1e-12}
|
||||
if jtu.device_under_test() == "tpu":
|
||||
tol[np.float32] = tol[np.complex64] = 4e-2
|
||||
self._CheckAgainstNumpy(np_fun, jnp.matmul, args_maker,
|
||||
check_dtypes=True, tol=tol)
|
||||
self._CompileAndCheck(jnp.matmul, args_maker, check_dtypes=True, atol=tol,
|
||||
rtol=tol)
|
||||
self._CheckAgainstNumpy(np_fun, jnp.matmul, args_maker, tol=tol)
|
||||
self._CompileAndCheck(jnp.matmul, args_maker, atol=tol, rtol=tol)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_{}_{}".format(
|
||||
@ -902,9 +901,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np.complex64: 1e-3, np.complex128: 1e-12}
|
||||
if jtu.device_under_test() == "tpu":
|
||||
tol[np.float32] = tol[np.complex64] = 2e-1
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True,
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
|
||||
tol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
def testTensordotErrors(self):
|
||||
a = np.random.random((3, 2, 2))
|
||||
@ -941,8 +940,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(element_shape, dtype), rng(test_shape, dtype)]
|
||||
jnp_fun = lambda e, t: jnp.isin(e, t, invert=invert)
|
||||
np_fun = lambda e, t: np.isin(e, t, invert=invert)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -960,8 +959,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(element_shape, dtype), rng(test_shape, dtype)]
|
||||
jnp_fun = lambda e, t: jnp.in1d(e, t, invert=invert)
|
||||
np_fun = lambda e, t: np.in1d(e, t, invert=invert)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -1015,7 +1014,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
# TODO(phawkins): the promotion behavior changed in Numpy 1.17.
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
def testClipError(self):
|
||||
with self.assertRaisesRegex(ValueError, "At most one of a_min and a_max.*"):
|
||||
@ -1044,15 +1043,15 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
def testOperatorRound(self):
|
||||
self.assertAllClose(round(np.float32(7.532), 1),
|
||||
round(jnp.float32(7.5), 1), check_dtypes=True)
|
||||
round(jnp.float32(7.5), 1))
|
||||
self.assertAllClose(round(np.float32(1.234), 2),
|
||||
round(jnp.float32(1.234), 2), check_dtypes=True)
|
||||
round(jnp.float32(1.234), 2))
|
||||
self.assertAllClose(round(np.float32(1.234)),
|
||||
round(jnp.float32(1.234)), check_dtypes=False)
|
||||
self.assertAllClose(round(np.float32(7.532), 1),
|
||||
round(jnp.array(7.5, jnp.float32), 1), check_dtypes=True)
|
||||
round(jnp.array(7.5, jnp.float32), 1))
|
||||
self.assertAllClose(round(np.float32(1.234), 2),
|
||||
round(jnp.array(1.234, jnp.float32), 2), check_dtypes=True)
|
||||
round(jnp.array(1.234, jnp.float32), 2))
|
||||
self.assertAllClose(round(np.float32(1.234)),
|
||||
round(jnp.array(1.234, jnp.float32)),
|
||||
check_dtypes=False)
|
||||
@ -1098,7 +1097,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
|
||||
check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape=[{}]_reps={}".format(
|
||||
@ -1117,7 +1116,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
|
||||
check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}".format(
|
||||
@ -1128,7 +1127,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def testExtract(self, shape, dtype):
|
||||
rng = jtu.rand_some_zero(self.rng())
|
||||
args_maker = lambda: [rng(shape, jnp.float32), rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(np.extract, jnp.extract, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np.extract, jnp.extract, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_axis={}".format(
|
||||
@ -1150,7 +1149,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
np_fun = partial(np.compress, axis=axis)
|
||||
jnp_fun = partial(jnp.compress, axis=axis)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_condition=array[{}]_axis={}".format(
|
||||
@ -1166,7 +1165,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [np.array(condition), rng(shape, dtype)]
|
||||
np_fun = partial(np.compress, axis=axis)
|
||||
jnp_fun = partial(jnp.compress, axis=axis)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format(
|
||||
@ -1193,8 +1192,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def args_maker():
|
||||
return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)]
|
||||
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format(
|
||||
@ -1221,8 +1220,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def args_maker():
|
||||
return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)]
|
||||
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape=[{}]_axis={}_repeats={}".format(
|
||||
@ -1240,8 +1239,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_ind={}_inv={}_count={}".format(
|
||||
@ -1260,7 +1259,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
np_fun = lambda x: np.unique(x, return_index, return_inverse, return_counts)
|
||||
jnp_fun = lambda x: jnp.unique(x, return_index, return_inverse, return_counts)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
|
||||
def testIssue1233(self):
|
||||
'''
|
||||
@ -1273,10 +1272,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
lax_ans = jnp.repeat(m, repeats, axis)
|
||||
numpy_ans = np.repeat(m, repeats, axis)
|
||||
|
||||
self.assertAllClose(lax_ans, numpy_ans, check_dtypes=True, rtol=tol, atol=tol)
|
||||
self.assertAllClose(lax_ans, numpy_ans, rtol=tol, atol=tol)
|
||||
|
||||
jnp_fun = lambda arg: jnp.repeat(arg, repeats = repeats, axis=axis)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
m = jnp.array([1,2,3,4,5,6])
|
||||
args_maker = lambda: [m]
|
||||
@ -1313,8 +1312,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
attempt_sideeffect(np_input)
|
||||
attempt_sideeffect(jnp_input)
|
||||
|
||||
self.assertAllClose(np_input, expected_np_input_after_call, check_dtypes=True)
|
||||
self.assertAllClose(jnp_input, expected_jnp_input_after_call, check_dtypes=True)
|
||||
self.assertAllClose(np_input, expected_np_input_after_call)
|
||||
self.assertAllClose(jnp_input, expected_jnp_input_after_call)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "op={}_xshape=[{}]_yshape=[{}]_mode={}".format(
|
||||
@ -1339,7 +1338,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
jnp_fun = partial(jnp_op, mode=mode, precision=precision)
|
||||
tol = {np.float16: 2e-1, np.float32: 1e-2, np.float64: 1e-14}
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "op={}_shape=[{}]_axis={}_out_dtype={}".format(
|
||||
@ -1365,9 +1364,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
tol_thresholds = {dtypes.bfloat16: 4e-2}
|
||||
tol = max(jtu.tolerance(dtype, tol_thresholds),
|
||||
jtu.tolerance(out_dtype, tol_thresholds))
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True,
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
|
||||
tol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -1411,8 +1410,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np_fun = lambda: np.tri(n, M=m, k=k, dtype=dtype)
|
||||
jnp_fun = lambda: jnp.tri(n, M=m, k=k, dtype=dtype)
|
||||
args_maker = lambda: []
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_op={}_shape={}_k={}".format(
|
||||
@ -1428,8 +1427,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np_fun = lambda arg: getattr(np, op)(arg, k=k)
|
||||
jnp_fun = lambda arg: getattr(jnp, op)(arg, k=k)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_ndim={}_n={}".format(ndim, n),
|
||||
@ -1453,8 +1452,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np_fun = lambda arg: np.diag(arg, k)
|
||||
jnp_fun = lambda arg: jnp.diag(arg, k)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_offset={}_axis1={}_axis2={}".format(
|
||||
@ -1472,8 +1471,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np_fun = lambda arg: np.diagonal(arg, offset, axis1, axis2)
|
||||
jnp_fun = lambda arg: jnp.diagonal(arg, offset, axis1, axis2)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_n={}".format(np.dtype(dtype).name, n),
|
||||
@ -1484,8 +1483,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np_fun = lambda: np.identity(n, dtype)
|
||||
jnp_fun = lambda: jnp.identity(n, dtype)
|
||||
args_maker = lambda: []
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_x1={}_x2={}_x1_rng={}".format(
|
||||
@ -1515,8 +1514,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
jnp_fun = lambda x1, x2: jnp.ldexp(x1, x2)
|
||||
args_maker = lambda: [x1_rng(x1_shape, x1_dtype),
|
||||
x2_rng(x2_shape, np.int32)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_x={}_rng_factory={}".format(
|
||||
@ -1539,8 +1538,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np_fun = lambda x: np.frexp(x)
|
||||
jnp_fun = lambda x: jnp.frexp(x)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -1565,8 +1564,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
return np.trace(arg, offset, axis1, axis2, out_dtype)
|
||||
jnp_fun = lambda arg: jnp.trace(arg, offset, axis1, axis2, out_dtype)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_a={}_v={}_side={}".format(
|
||||
@ -1585,8 +1584,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [jnp.sort(rng(ashape, dtype)), rng(vshape, dtype)]
|
||||
np_fun = lambda a, v: np.searchsorted(a, v, side=side)
|
||||
jnp_fun = lambda a, v: jnp.searchsorted(a, v, side=side)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_x={}_bins={}_right={}_reverse={}".format(
|
||||
@ -1607,8 +1606,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(xshape, dtype), jnp.sort(rng(binshape, dtype))[order]]
|
||||
np_fun = lambda x, bins: np.digitize(x, bins, right=right)
|
||||
jnp_fun = lambda x, bins: jnp.digitize(x, bins, right=right)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_axis={}".format(
|
||||
@ -1629,7 +1628,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
|
||||
np_fun = _promote_like_jnp(partial(np.stack, axis=axis))
|
||||
jnp_fun = partial(jnp.stack, axis=axis)
|
||||
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -1651,7 +1650,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
|
||||
np_fun = _promote_like_jnp(getattr(np, op))
|
||||
jnp_fun = getattr(jnp, op)
|
||||
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_inshape={}_outdtype={}".format(
|
||||
@ -1667,8 +1666,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np_fun = lambda fill_value: np.full(shape, fill_value, dtype=out_dtype)
|
||||
jnp_fun = lambda fill_value: jnp.full(shape, fill_value, dtype=out_dtype)
|
||||
args_maker = lambda: [rng((), fill_value_dtype)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
@ -1684,8 +1683,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def args_maker(): return []
|
||||
np_op = partial(np_op, shape, dtype)
|
||||
jnp_op = partial(jnp_op, shape, dtype)
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
def testOnesWithInvalidShape(self):
|
||||
with self.assertRaises(TypeError):
|
||||
@ -1708,8 +1707,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np_fun = lambda x, fill_value: np.full_like(x, fill_value, dtype=out_dtype)
|
||||
jnp_fun = lambda x, fill_value: jnp.full_like(x, fill_value, dtype=out_dtype)
|
||||
args_maker = lambda: [rng(shape, in_dtype), rng((), fill_value_dtype)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_axis={}_{}sections".format(
|
||||
@ -1725,8 +1724,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np_fun = lambda x: np.split(x, num_sections, axis=axis)
|
||||
jnp_fun = lambda x: jnp.split(x, num_sections, axis=axis)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
def testSplitTypeError(self):
|
||||
# If we pass an ndarray for indices_or_sections -> no error
|
||||
@ -1774,7 +1773,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
# linspace() compares poorly to numpy when using bfloat16
|
||||
if dtype != jnp.bfloat16:
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True,
|
||||
self._CompileAndCheck(jnp_fun, args_maker,
|
||||
atol=tol, rtol=tol)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -1808,7 +1807,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
if dtype != jnp.bfloat16:
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
|
||||
tol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_axis={}_{}sections".format(
|
||||
@ -1833,8 +1832,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np_fun = lambda x: fn(np, axis)(x, num_sections)
|
||||
jnp_fun = lambda x: fn(jnp, axis)(x, num_sections)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_inshape={}_outshape={}_order={}".format(
|
||||
@ -1860,8 +1859,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np_fun = lambda x: np.reshape(x, out_shape, order=order)
|
||||
jnp_fun = lambda x: jnp.reshape(x, out_shape, order=order)
|
||||
args_maker = lambda: [rng(arg_shape, dtype)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_inshape={}_outshape={}".format(
|
||||
@ -1880,8 +1879,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np_fun = lambda x: np.reshape(x, out_shape)
|
||||
jnp_fun = lambda x: x.reshape(*out_shape)
|
||||
args_maker = lambda: [rng(arg_shape, dtype)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_inshape={}_expanddim={!r}".format(
|
||||
@ -1898,11 +1897,11 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np_fun = lambda x: np.expand_dims(x, dim)
|
||||
jnp_fun = lambda x: jnp.expand_dims(x, dim)
|
||||
args_maker = lambda: [rng(arg_shape, dtype)]
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
if isinstance(dim, tuple) and numpy_version < (1, 18, 0):
|
||||
raise SkipTest("support for multiple axes added in NumPy 1.18.0")
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_inshape={}_axes=({},{})".format(
|
||||
@ -1918,8 +1917,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np_fun = lambda x: np.swapaxes(x, ax1, ax2)
|
||||
jnp_fun = lambda x: jnp.swapaxes(x, ax1, ax2)
|
||||
args_maker = lambda: [rng(arg_shape, dtype)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_inshape={}_axis={!r}".format(
|
||||
@ -1940,8 +1939,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np_fun = lambda x: np.squeeze(x, ax)
|
||||
jnp_fun = lambda x: jnp.squeeze(x, ax)
|
||||
args_maker = lambda: [rng(arg_shape, dtype)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_axis={}_weights={}_returned={}".format(
|
||||
@ -2004,8 +2003,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
else:
|
||||
np_fun = partial(np.array, dtype=dtype)
|
||||
jnp_fun = jnp.array
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
def testArrayUnsupportedDtypeError(self):
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
@ -2042,8 +2041,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
ans = jnp.array(bytearray(b'\x2a'))
|
||||
self.assertAllClose(
|
||||
ans,
|
||||
np.array([0x2a], dtype=np.uint8),
|
||||
check_dtypes=True)
|
||||
np.array([0x2a], dtype=np.uint8))
|
||||
|
||||
def testIsClose(self):
|
||||
c_isclose = api.jit(jnp.isclose)
|
||||
@ -2195,8 +2193,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
|
||||
jnp_op = lambda x: jnp.flip(x, axis)
|
||||
np_op = lambda x: np.flip(x, axis)
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}".format(
|
||||
@ -2210,8 +2208,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
|
||||
jnp_op = lambda x: jnp.flipud(x)
|
||||
np_op = lambda x: np.flipud(x)
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -2226,8 +2224,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
|
||||
jnp_op = lambda x: jnp.fliplr(x)
|
||||
np_op = lambda x: np.fliplr(x)
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -2248,15 +2246,15 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
|
||||
jnp_op = lambda x: jnp.rot90(x, k, axes)
|
||||
np_op = lambda x: np.rot90(x, k, axes)
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
# TODO(mattjj): test infix operator overrides
|
||||
|
||||
def testRavel(self):
|
||||
rng = np.random.RandomState(0)
|
||||
args_maker = lambda: [rng.randn(3, 4).astype("float32")]
|
||||
self._CompileAndCheck(lambda x: x.ravel(), args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lambda x: x.ravel(), args_maker)
|
||||
|
||||
@parameterized.parameters(
|
||||
(0, (2, 1, 3)),
|
||||
@ -2265,12 +2263,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
([0, 1, 2], (2, 2)),
|
||||
([[[0, 1], [2, 3]]], (2, 2)))
|
||||
def testUnravelIndex(self, flat_index, shape):
|
||||
self._CheckAgainstNumpy(
|
||||
np.unravel_index,
|
||||
jnp.unravel_index,
|
||||
lambda: (flat_index, shape),
|
||||
check_dtypes=True
|
||||
)
|
||||
self._CheckAgainstNumpy(np.unravel_index, jnp.unravel_index,
|
||||
lambda: (flat_index, shape))
|
||||
|
||||
def testUnravelIndexOOB(self):
|
||||
self.assertEqual(jnp.unravel_index(2, (2,)), (1,))
|
||||
@ -2282,8 +2276,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng.randn(3, 4).astype("float32")]
|
||||
np_op = lambda x: np.asarray(x).astype(jnp.int32)
|
||||
jnp_op = lambda x: jnp.asarray(x).astype(jnp.int32)
|
||||
self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_dtype={}".format(
|
||||
@ -2305,8 +2299,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
jnp_op = lambda x: jnp.asarray(x).view(dtype)
|
||||
# Above may produce signaling nans; ignore warnings from invalid values.
|
||||
with np.errstate(invalid='ignore'):
|
||||
self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
def testPathologicalFloats(self):
|
||||
args_maker = lambda: [np.array([
|
||||
@ -2325,8 +2319,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np_op = lambda x: np.asarray(x).view('float32').view('uint32')
|
||||
jnp_op = lambda x: jnp.asarray(x).view('float32').view('uint32')
|
||||
|
||||
self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
# TODO(mattjj): test other ndarray-like method overrides
|
||||
|
||||
@ -2340,29 +2334,29 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
# from https://github.com/google/jax/issues/145
|
||||
expected = np.arange(0.0, 1.0, 0.1, dtype=jnp.float_)
|
||||
ans = jnp.arange(0.0, 1.0, 0.1)
|
||||
self.assertAllClose(expected, ans, check_dtypes=True)
|
||||
self.assertAllClose(expected, ans)
|
||||
|
||||
def testSortManually(self):
|
||||
# manual tests for sort are nice because we don't have to worry about ties.
|
||||
# lax.sort is tested combinatorially.
|
||||
ans = jnp.sort(np.array([16, 15, 23, 42, 8, 4]))
|
||||
expected = np.array([4, 8, 15, 16, 23, 42])
|
||||
self.assertAllClose(expected, ans, check_dtypes=True)
|
||||
self.assertAllClose(expected, ans)
|
||||
|
||||
a = np.array([[1, 4], [3, 1]])
|
||||
ans = jnp.sort(a, axis=None)
|
||||
expected = np.array([1, 1, 3, 4])
|
||||
self.assertAllClose(expected, ans, check_dtypes=True)
|
||||
self.assertAllClose(expected, ans)
|
||||
|
||||
a = np.array([[1, 4], [3, 1]])
|
||||
ans = jnp.sort(a) # last axis
|
||||
expected = np.array([[1, 4], [1, 3]])
|
||||
self.assertAllClose(expected, ans, check_dtypes=True)
|
||||
self.assertAllClose(expected, ans)
|
||||
|
||||
a = np.array([[1, 4], [3, 1]])
|
||||
ans = jnp.sort(a, axis=0)
|
||||
expected = np.array([[1, 1], [3, 4]])
|
||||
self.assertAllClose(expected, ans, check_dtypes=True)
|
||||
self.assertAllClose(expected, ans)
|
||||
|
||||
def testArgsortManually(self):
|
||||
x = np.array([16, 15, 23, 42, 8, 4])
|
||||
@ -2394,8 +2388,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [np.random.randint(50, size=(5 ,5))]
|
||||
jnp_op = lambda x: jnp.msort(x)
|
||||
np_op = lambda x: np.msort(x)
|
||||
self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_shifts={}_axis={}".format(
|
||||
@ -2420,8 +2414,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(shape, dtype), np.array(shifts)]
|
||||
jnp_op = partial(jnp.roll, axis=axis)
|
||||
np_op = partial(np.roll, axis=axis)
|
||||
self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_axis={}_start={}".format(
|
||||
@ -2439,8 +2433,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
jnp_op = partial(jnp.rollaxis, axis=axis, start=start)
|
||||
np_op = partial(np.rollaxis, axis=axis, start=start)
|
||||
self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_axis={}_bitorder={}".format(
|
||||
@ -2459,8 +2453,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
jnp_op = partial(jnp.packbits, axis=axis, bitorder=bitorder)
|
||||
np_op = partial(np.packbits, axis=axis, bitorder=bitorder)
|
||||
self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_axis={}_bitorder={}_count={}".format(
|
||||
@ -2479,8 +2473,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
jnp_op = partial(jnp.unpackbits, axis=axis, bitorder=bitorder)
|
||||
np_op = partial(np.unpackbits, axis=axis, bitorder=bitorder)
|
||||
self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_index={}_axis={}_mode={}".format(
|
||||
@ -2506,8 +2500,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
rng_indices = jtu.rand_int(self.rng(), -5, 5)
|
||||
jnp_op = lambda x, i: jnp.take(x, i, axis=axis, mode=mode)
|
||||
np_op = lambda x, i: np.take(x, i, axis=axis, mode=mode)
|
||||
self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_ishape={}_axis={}".format(
|
||||
@ -2541,8 +2535,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
if hasattr(np, "take_along_axis"):
|
||||
np_op = lambda x, i: np.take_along_axis(x, i, axis=axis)
|
||||
self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_n={}_increasing={}".format(
|
||||
@ -2606,9 +2600,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
rng = rng_factory(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)
|
||||
for shape, dtype in zip(shapes, dtypes)]
|
||||
self._CheckAgainstNumpy(np.ix_, jnp.ix_, args_maker,
|
||||
check_dtypes=True)
|
||||
self._CompileAndCheck(jnp.ix_, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np.ix_, jnp.ix_, args_maker)
|
||||
self._CompileAndCheck(jnp.ix_, args_maker)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
@ -2630,8 +2623,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
dtype=dtype, sparse=sparse)
|
||||
jnp_fun = partial(jnp.indices, dimensions=dimensions,
|
||||
dtype=dtype, sparse=sparse)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
@ -2681,7 +2674,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
jtu.tolerance(q_dtype, tol_spec))
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
|
||||
tol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, rtol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol)
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -2714,7 +2707,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
tol = jtu.tolerance(a_dtype, tol_spec)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
|
||||
tol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, rtol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol)
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -2745,9 +2738,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = self._GetArgsMaker(rng_factory(self.rng()), shapes, dtypes)
|
||||
def np_fun(cond, x, y):
|
||||
return _promote_like_jnp(partial(np.where, cond))(x, y)
|
||||
self._CheckAgainstNumpy(np_fun, jnp.where, args_maker,
|
||||
check_dtypes=True)
|
||||
self._CompileAndCheck(jnp.where, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp.where, args_maker)
|
||||
self._CompileAndCheck(jnp.where, args_maker)
|
||||
|
||||
def testWhereScalarPromotion(self):
|
||||
x = jnp.where(jnp.array([True, False]), 3,
|
||||
@ -2782,7 +2774,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np.asarray(default, dtype=dtype))
|
||||
self._CheckAgainstNumpy(np_fun, jnp.select, args_maker,
|
||||
check_dtypes=False)
|
||||
self._CompileAndCheck(jnp.select, args_maker, check_dtypes=True,
|
||||
self._CompileAndCheck(jnp.select, args_maker,
|
||||
rtol={np.float64: 1e-7, np.complex128: 1e-7})
|
||||
|
||||
|
||||
@ -2823,7 +2815,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
a = np.arange(6) + 1
|
||||
ans = jnp.reshape(a, (3, 2), order='F')
|
||||
expected = np.reshape(a, (3, 2), order='F')
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_op={}_dtype={}".format(op, pytype.__name__),
|
||||
@ -2836,8 +2828,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np_fun = lambda arg: getattr(np, op)(arg).astype(dtype)
|
||||
jnp_fun = lambda arg: getattr(jnp, op)(arg)
|
||||
args_maker = lambda: [pytype(2)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{
|
||||
@ -2865,7 +2857,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
jnp_fun = partial(jnp.bincount, minlength=minlength, length=length)
|
||||
|
||||
if length is not None:
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
if length is None:
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
|
||||
@ -2906,32 +2898,29 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
])))
|
||||
def testBlock(self, input):
|
||||
args_maker = lambda: [input]
|
||||
self._CheckAgainstNumpy(np.block, jnp.block, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp.block, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np.block, jnp.block, args_maker)
|
||||
self._CompileAndCheck(jnp.block, args_maker)
|
||||
|
||||
def testLongLong(self):
|
||||
self.assertAllClose(np.int64(7), api.jit(lambda x: x)(np.longlong(7)),
|
||||
check_dtypes=True)
|
||||
self.assertAllClose(np.int64(7), api.jit(lambda x: x)(np.longlong(7)))
|
||||
|
||||
def testArange(self):
|
||||
# test cases inspired by dask tests at
|
||||
# https://github.com/dask/dask/blob/master/dask/array/tests/test_creation.py#L92
|
||||
self.assertAllClose(jnp.arange(77),
|
||||
np.arange(77, dtype=jnp.int_), check_dtypes=True)
|
||||
np.arange(77, dtype=jnp.int_))
|
||||
self.assertAllClose(jnp.arange(2, 13),
|
||||
np.arange(2, 13, dtype=jnp.int_), check_dtypes=True)
|
||||
np.arange(2, 13, dtype=jnp.int_))
|
||||
self.assertAllClose(jnp.arange(4, 21, 9),
|
||||
np.arange(4, 21, 9, dtype=jnp.int_), check_dtypes=True)
|
||||
np.arange(4, 21, 9, dtype=jnp.int_))
|
||||
self.assertAllClose(jnp.arange(53, 5, -3),
|
||||
np.arange(53, 5, -3, dtype=jnp.int_),
|
||||
check_dtypes=True)
|
||||
np.arange(53, 5, -3, dtype=jnp.int_))
|
||||
self.assertAllClose(jnp.arange(77, dtype=float),
|
||||
np.arange(77, dtype=float), check_dtypes=True)
|
||||
np.arange(77, dtype=float))
|
||||
self.assertAllClose(jnp.arange(2, 13, dtype=int),
|
||||
np.arange(2, 13, dtype=int), check_dtypes=True)
|
||||
np.arange(2, 13, dtype=int))
|
||||
self.assertAllClose(jnp.arange(0, 1, -0.5),
|
||||
np.arange(0, 1, -0.5, dtype=jnp.float_),
|
||||
check_dtypes=True)
|
||||
np.arange(0, 1, -0.5, dtype=jnp.float_))
|
||||
|
||||
self.assertRaises(TypeError, lambda: jnp.arange())
|
||||
|
||||
@ -2976,8 +2965,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
# argument.
|
||||
return lax.tie_in(y, 7.)
|
||||
|
||||
self.assertAllClose(np.zeros(3,), api.grad(f)(np.ones(3,)),
|
||||
check_dtypes=True)
|
||||
self.assertAllClose(np.zeros(3,), api.grad(f)(np.ones(3,)))
|
||||
|
||||
# NOTE(mattjj): I disabled this test when removing lax._safe_mul because this
|
||||
# is a numerical stability issue that should be solved with a custom jvp rule
|
||||
@ -2985,8 +2973,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
# def testIssue777(self):
|
||||
# x = jnp.linspace(-200, 0, 4, dtype=np.float32)
|
||||
# f = api.grad(lambda x: jnp.sum(1 / (1 + jnp.exp(-x))))
|
||||
# self.assertAllClose(f(x), np.array([0., 0., 0., 0.25], dtype=np.float32),
|
||||
# check_dtypes=True)
|
||||
# self.assertAllClose(f(x), np.array([0., 0., 0., 0.25], dtype=np.float32))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
@ -3017,7 +3004,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
expected = np_op(x)
|
||||
actual = jnp_op(x)
|
||||
tol = jtu.tolerance(dtype, {np.float32: 1e-3, np.float64: 1e-7})
|
||||
self.assertAllClose(expected, actual, check_dtypes=True, atol=tol,
|
||||
self.assertAllClose(expected, actual, atol=tol,
|
||||
rtol=tol)
|
||||
|
||||
def testIssue883(self):
|
||||
@ -3067,9 +3054,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
not jnp.issubdtype(out_dtype, jnp.complexfloating)):
|
||||
self.assertRaises(ValueError, lambda: jnp_fun(*args_maker()))
|
||||
else:
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True,
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
|
||||
tol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, rtol=tol,
|
||||
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol,
|
||||
atol=tol)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@ -3100,9 +3087,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
not jnp.issubdtype(out_dtype, jnp.complexfloating)):
|
||||
self.assertRaises(ValueError, lambda: jnp_fun(*args_maker()))
|
||||
else:
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True,
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
|
||||
tol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, rtol=tol,
|
||||
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol,
|
||||
atol=tol)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@ -3128,7 +3115,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
tol = jtu.join_tolerance(tol, jtu.tolerance(dtype))
|
||||
self._CheckAgainstNumpy(
|
||||
np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, atol=tol,
|
||||
self._CompileAndCheck(jnp_fun, args_maker, atol=tol,
|
||||
rtol=tol)
|
||||
|
||||
def testIssue967(self):
|
||||
@ -3157,7 +3144,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(
|
||||
np_fun, jnp_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-2 if jtu.device_under_test() == "tpu" else None)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_{}_{}".format(jtu.format_shape_dtype_string(shape, dtype),
|
||||
@ -3184,8 +3171,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
(None if begin_dtype is None else rng(begin_shape, begin_dtype))]
|
||||
np_fun = lambda x, to_end, to_begin: np.ediff1d(x, to_end, to_begin)
|
||||
jnp_fun = lambda x, to_end, to_begin: jnp.ediff1d(x, to_end, to_begin)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
def testEDiff1dWithDtypeCast(self):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -3196,8 +3183,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(shape, dtype), rng(shape, end_dtype), rng(shape, dtype)]
|
||||
np_fun = lambda x, to_end, to_begin: np.ediff1d(x, to_end, to_begin)
|
||||
jnp_fun = lambda x, to_end, to_begin: jnp.ediff1d(x, to_end, to_begin)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
@ -3216,7 +3203,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
[dtype] * len(shapes))
|
||||
np_fun = partial(np.meshgrid, indexing=indexing, sparse=sparse)
|
||||
jnp_fun = partial(jnp.meshgrid, indexing=indexing, sparse=sparse)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
@ -3272,7 +3259,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
rng = rng_factory(self.rng())
|
||||
endpoints = rng((2,), dtype)
|
||||
out = jnp.linspace(*endpoints, 10, dtype=dtype)
|
||||
self.assertAllClose(out[[0, -1]], endpoints, check_dtypes=True, rtol=0, atol=0)
|
||||
self.assertAllClose(out[[0, -1]], endpoints, rtol=0, atol=0)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
@ -3454,8 +3441,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = self._GetArgsMaker(rng, [from_shape], [np.float32])
|
||||
np_op = lambda x: np.broadcast_to(x, to_shape)
|
||||
jnp_op = lambda x: jnp.broadcast_to(x, to_shape)
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
def testBroadcastToIssue1522(self):
|
||||
self.assertRaisesRegex(
|
||||
@ -3540,7 +3527,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np_fun = lambda y: np.gradient(y, *varargs, axis=axis)
|
||||
self._CheckAgainstNumpy(
|
||||
np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
def testZerosShapeErrors(self):
|
||||
# see https://github.com/google/jax/issues/1822
|
||||
@ -3556,9 +3543,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
def testTraceMethod(self):
|
||||
x = self.rng().randn(3, 4).astype(jnp.float_)
|
||||
self.assertAllClose(x.trace(), jnp.array(x).trace(), check_dtypes=True)
|
||||
self.assertAllClose(x.trace(), api.jit(lambda y: y.trace())(x),
|
||||
check_dtypes=True)
|
||||
self.assertAllClose(x.trace(), jnp.array(x).trace())
|
||||
self.assertAllClose(x.trace(), api.jit(lambda y: y.trace())(x))
|
||||
|
||||
def testIntegerPowersArePrecise(self):
|
||||
# See https://github.com/google/jax/pull/3036
|
||||
|
@ -141,8 +141,8 @@ class VectorizeTest(jtu.JaxTestCase):
|
||||
return y
|
||||
|
||||
x = jnp.arange(3)
|
||||
self.assertAllClose(x, f('foo', x), check_dtypes=True)
|
||||
self.assertAllClose(x, jax.jit(f, 0)('foo', x), check_dtypes=True)
|
||||
self.assertAllClose(x, f('foo', x))
|
||||
self.assertAllClose(x, jax.jit(f, 0)('foo', x))
|
||||
|
||||
def test_exclude_second(self):
|
||||
|
||||
@ -153,8 +153,8 @@ class VectorizeTest(jtu.JaxTestCase):
|
||||
return x
|
||||
|
||||
x = jnp.arange(3)
|
||||
self.assertAllClose(x, f(x, 'foo'), check_dtypes=True)
|
||||
self.assertAllClose(x, jax.jit(f, 1)(x, 'foo'), check_dtypes=True)
|
||||
self.assertAllClose(x, f(x, 'foo'))
|
||||
self.assertAllClose(x, jax.jit(f, 1)(x, 'foo'))
|
||||
|
||||
def test_exclude_errors(self):
|
||||
with self.assertRaisesRegex(
|
||||
|
@ -92,7 +92,6 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
partial(scipy_cg, M=M, maxiter=1),
|
||||
partial(lax_cg, M=M, maxiter=1),
|
||||
args_maker,
|
||||
check_dtypes=True,
|
||||
tol=1e-3)
|
||||
|
||||
# TODO(shoyer,mattjj): I had to loosen the tolerance for complex64[7,7]
|
||||
@ -101,14 +100,12 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
partial(scipy_cg, M=M, maxiter=3),
|
||||
partial(lax_cg, M=M, maxiter=3),
|
||||
args_maker,
|
||||
check_dtypes=True,
|
||||
tol=3e-3)
|
||||
|
||||
self._CheckAgainstNumpy(
|
||||
np.linalg.solve,
|
||||
partial(lax_cg, M=M, atol=1e-6),
|
||||
args_maker,
|
||||
check_dtypes=True,
|
||||
tol=2e-2)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -125,10 +122,10 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
|
||||
expected = np.linalg.solve(posify(a), b)
|
||||
actual = lax_cg(posify(a), b)
|
||||
self.assertAllClose(expected, actual, check_dtypes=True)
|
||||
self.assertAllClose(expected, actual)
|
||||
|
||||
actual = jit(lax_cg)(posify(a), b)
|
||||
self.assertAllClose(expected, actual, check_dtypes=True)
|
||||
self.assertAllClose(expected, actual)
|
||||
|
||||
# numerical gradients are only well defined if ``a`` is guaranteed to be
|
||||
# positive definite.
|
||||
@ -141,7 +138,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
b = jnp.arange(9.0).reshape((3, 3))
|
||||
expected = b / 2
|
||||
actual, _ = jax.scipy.sparse.linalg.cg(A, b)
|
||||
self.assertAllClose(expected, actual, check_dtypes=True)
|
||||
self.assertAllClose(expected, actual)
|
||||
|
||||
def test_cg_pytree(self):
|
||||
A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
|
||||
|
@ -108,8 +108,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims)
|
||||
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list(
|
||||
@ -129,7 +129,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
args = args_maker()
|
||||
self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
|
||||
check_dtypes=False)
|
||||
self._CompileAndCheck(lax_op, args_maker, check_dtypes=True, rtol=1e-5)
|
||||
self._CompileAndCheck(lax_op, args_maker, rtol=1e-5)
|
||||
|
||||
if test_autodiff:
|
||||
jtu.check_grads(lax_op, args, order=1,
|
||||
@ -153,14 +153,14 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
|
||||
rng = rng_factory(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
|
||||
tol={onp.float32: 1e-3, onp.float64: 1e-14})
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
def testIssue980(self):
|
||||
x = onp.full((4,), -1e20, dtype=onp.float32)
|
||||
self.assertAllClose(onp.zeros((4,), dtype=onp.float32),
|
||||
lsp_special.expit(x), check_dtypes=True)
|
||||
lsp_special.expit(x))
|
||||
|
||||
def testXlogyShouldReturnZero(self):
|
||||
self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False)
|
||||
|
@ -187,7 +187,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
rng = rng_factory(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype) for shape in shapes]
|
||||
op = getattr(lax, op_name)
|
||||
self._CompileAndCheck(op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list(
|
||||
@ -222,7 +222,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
rng = rng_factory(self.rng())
|
||||
args_maker = lambda: [rng((2, 3), from_dtype)]
|
||||
op = lambda x: lax.convert_element_type(x, to_dtype)
|
||||
self._CompileAndCheck(op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_from_dtype={}_to_dtype={}"
|
||||
@ -249,7 +249,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
rng = rng_factory(self.rng())
|
||||
args_maker = lambda: [rng((2, 3), from_dtype)]
|
||||
op = lambda x: lax.bitcast_convert_type(x, to_dtype)
|
||||
self._CompileAndCheck(op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_from_dtype={}_to_dtype={}"
|
||||
@ -284,7 +284,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
rng = rng_factory(self.rng())
|
||||
shapes = [min_shape, operand_shape, max_shape]
|
||||
args_maker = lambda: [rng(shape, dtype) for shape in shapes]
|
||||
self._CompileAndCheck(lax.clamp, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lax.clamp, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_min_shape={}_operand_shape={}_max_shape={}".format(
|
||||
@ -325,7 +325,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
for size, _ in zip(itertools.cycle([3, 1, 4]), range(num_arrs))]
|
||||
args_maker = lambda: [rng(shape, dtype) for shape in shapes]
|
||||
op = lambda *args: lax.concatenate(args, dim)
|
||||
self._CompileAndCheck(op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format(
|
||||
@ -368,7 +368,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
def fun(lhs, rhs):
|
||||
return lax.conv(lhs, rhs, strides, padding)
|
||||
|
||||
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
@ -419,7 +419,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
return lax.conv_with_general_padding(
|
||||
lhs, rhs, strides, padding, lhs_dilation, rhs_dilation)
|
||||
|
||||
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
|
||||
@ -502,7 +502,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
dimension_numbers, feature_group_count=feature_group_count,
|
||||
batch_group_count=batch_group_count)
|
||||
|
||||
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
|
||||
# TODO(mattjj): test conv_general_dilated against numpy
|
||||
|
||||
@ -512,7 +512,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
return [rng((10, 5), onp.float32), rng((5, 7), onp.float32)]
|
||||
jnp_fun = partial(lax.conv_general_dilated, window_strides=(),
|
||||
padding='VALID', dimension_numbers=('NC', 'IO', 'NC'))
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
self._CheckAgainstNumpy(jnp_fun, onp.dot, args_maker, tol=.1)
|
||||
|
||||
|
||||
@ -685,8 +685,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
def testDot(self, lhs_shape, rhs_shape, dtype, precision, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
|
||||
self._CompileAndCheck(partial(lax.dot, precision=precision), args_maker,
|
||||
check_dtypes=True)
|
||||
self._CompileAndCheck(partial(lax.dot, precision=precision), args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_lhs_shape={}_rhs_shape={}".format(
|
||||
@ -798,7 +797,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
rng = rng_factory(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
op = lambda x: lax.broadcast(x, broadcast_sizes)
|
||||
self._CompileAndCheck(op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_broadcast_sizes={}".format(
|
||||
@ -835,7 +834,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
rng = rng_factory(self.rng())
|
||||
args_maker = lambda: [rng(inshape, dtype)]
|
||||
op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
|
||||
self._CompileAndCheck(op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_inshape={}_outshape={}_bcdims={}".format(
|
||||
@ -921,7 +920,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(arg_shape, onp.float32)]
|
||||
op = lambda x: lax.squeeze(x, dimensions)
|
||||
numpy_op = lambda x: lax_reference.squeeze(x, dimensions)
|
||||
self._CompileAndCheck(op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(op, args_maker)
|
||||
self._CheckAgainstNumpy(op, numpy_op, args_maker)
|
||||
check_grads(op, args_maker(), 2, ["fwd", "rev"], eps=1.)
|
||||
|
||||
@ -940,7 +939,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
rng = rng_factory(self.rng())
|
||||
args_maker = lambda: [rng(arg_shape, dtype)]
|
||||
op = lambda x: lax.reshape(x, out_shape)
|
||||
self._CompileAndCheck(op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_inshape={}_outshape={}".format(
|
||||
@ -971,7 +970,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
rng = rng_factory(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
fun = lambda operand: lax.pad(operand, onp.array(0, dtype), pads)
|
||||
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_inshape={}_pads={}"
|
||||
@ -1018,7 +1017,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
return [rng(pred_shape, onp.bool_), rng(arg_shape, arg_dtype),
|
||||
rng(arg_shape, arg_dtype)]
|
||||
rng = rng_factory(self.rng())
|
||||
return self._CompileAndCheck(lax.select, args_maker, check_dtypes=True)
|
||||
return self._CompileAndCheck(lax.select, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_predshape={}_argshapes={}".format(
|
||||
@ -1061,7 +1060,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
rng = rng_factory(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
op = lambda x: lax.slice(x, starts, limits, strides)
|
||||
self._CompileAndCheck(op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
@ -1109,7 +1108,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
rng = rng_factory(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype), onp.array(start_indices)]
|
||||
op = lambda x, starts: lax.dynamic_slice(x, starts, size_indices)
|
||||
self._CompileAndCheck(op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_start_indices={}_size_indices={}".format(
|
||||
@ -1159,8 +1158,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
return [rng(shape, dtype), rng(update_shape, dtype),
|
||||
onp.array(start_indices)]
|
||||
|
||||
self._CompileAndCheck(lax.dynamic_update_slice, args_maker,
|
||||
check_dtypes=True)
|
||||
self._CompileAndCheck(lax.dynamic_update_slice, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_start_indices={}_update_shape={}".format(
|
||||
@ -1202,7 +1200,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
rng = rng_factory(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
op = lambda x: lax.transpose(x, perm)
|
||||
self._CompileAndCheck(op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_perm={}".format(
|
||||
@ -1257,13 +1255,13 @@ class LaxTest(jtu.JaxTestCase):
|
||||
init_val = onp.asarray(init_val, dtype=dtype)
|
||||
fun = lambda operand, init_val: lax.reduce(operand, init_val, op, dims)
|
||||
args_maker = lambda: [rng(shape, dtype), init_val]
|
||||
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
|
||||
# we separately test the version that uses a concrete init_val because it
|
||||
# can hit different code paths
|
||||
fun = lambda operand: lax.reduce(operand, init_val, op, dims)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_op={}_dtype={}_padding={}"
|
||||
@ -1297,7 +1295,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
# pylint: disable=cell-var-from-loop
|
||||
for shape, dims, strides in all_configs:
|
||||
args_maker = lambda: [rng(shape, dtype), init_val]
|
||||
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
# pylint: enable=cell-var-from-loop
|
||||
|
||||
# we separately test the version that uses a concrete init_val because it
|
||||
@ -1308,7 +1306,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
# pylint: disable=cell-var-from-loop
|
||||
for shape, dims, strides in all_configs:
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
# pylint: enable=cell-var-from-loop
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -1329,9 +1327,9 @@ class LaxTest(jtu.JaxTestCase):
|
||||
def testCumulativeReduce(self, op, onp_op, shape, dtype, axis, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
fun = partial(op, axis=axis)
|
||||
onp_fun = partial(onp_op, axis=axis)
|
||||
onp_fun = partial(onp_op, axis=axis, dtype=dtype)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
self._CheckAgainstNumpy(fun, onp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -1350,7 +1348,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
fun = lambda x: lax.sort(x, dimension=axis)
|
||||
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_axis={}".format(
|
||||
@ -1399,7 +1397,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
return keys, values
|
||||
|
||||
fun = lambda keys, values: lax.sort_key_val(keys, values, axis)
|
||||
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_keyshape={}_valshape={}_axis={}".format(
|
||||
@ -1447,12 +1445,12 @@ class LaxTest(jtu.JaxTestCase):
|
||||
values = self.rng().permutation(flat_values).reshape(shape)
|
||||
return [values]
|
||||
def reference_top_k(x):
|
||||
bcast_idxs = onp.broadcast_to(onp.arange(shape[-1]), shape)
|
||||
bcast_idxs = onp.broadcast_to(onp.arange(shape[-1], dtype=onp.int32), shape)
|
||||
sorted_vals, sorted_idxs = lax_reference.sort_key_val(x, bcast_idxs)
|
||||
return sorted_vals[..., :-k-1:-1], sorted_idxs[..., :-k-1:-1]
|
||||
op = lambda vs: lax.top_k(vs, k=k)
|
||||
self._CheckAgainstNumpy(op, reference_top_k, args_maker)
|
||||
self._CompileAndCheck(op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_lhs_shape={}_rhs_shape={}"
|
||||
@ -1468,7 +1466,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
def testBatchMatMul(self, lhs_shape, rhs_shape, dtype, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
arg_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
|
||||
self._CompileAndCheck(lax.batch_matmul, arg_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lax.batch_matmul, arg_maker)
|
||||
|
||||
def testCollapse(self):
|
||||
|
||||
@ -1498,7 +1496,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
rand_idxs = lambda: tuple(rng(e.shape, e.dtype) for e in idxs)
|
||||
args_maker = lambda: [rng(shape, dtype), rand_idxs()]
|
||||
fun = lambda src, idxs: lax.index_take(src, idxs, axes)
|
||||
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}".format(
|
||||
@ -1531,7 +1529,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
rand_idxs = lambda: rng_idx(idxs.shape, idxs.dtype)
|
||||
args_maker = lambda: [rng(shape, dtype), rand_idxs()]
|
||||
fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
|
||||
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
|
||||
@ -1562,7 +1560,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(arg_shape, dtype), rand_idxs(),
|
||||
rng(update_shape, dtype)]
|
||||
fun = partial(lax.scatter_add, dimension_numbers=dnums)
|
||||
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
|
||||
@ -1593,7 +1591,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(arg_shape, dtype), rand_idxs(),
|
||||
rng(update_shape, dtype)]
|
||||
fun = partial(lax.scatter_min, dimension_numbers=dnums)
|
||||
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
|
||||
@ -1624,7 +1622,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(arg_shape, dtype), rand_idxs(),
|
||||
rng(update_shape, dtype)]
|
||||
fun = partial(lax.scatter_max, dimension_numbers=dnums)
|
||||
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
|
||||
@ -1655,7 +1653,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(arg_shape, dtype), rand_idxs(),
|
||||
rng(update_shape, dtype)]
|
||||
fun = partial(lax.scatter, dimension_numbers=dnums)
|
||||
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
|
||||
def testIssue831(self):
|
||||
# Tests the DeviceTuple constant handler
|
||||
@ -1667,7 +1665,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
|
||||
def testReshapeWithUnusualShapes(self):
|
||||
ans = lax.reshape(onp.ones((3,), onp.float32), (lax.add(1, 2), 1))
|
||||
self.assertAllClose(ans, onp.ones((3, 1), onp.float32), check_dtypes=True)
|
||||
self.assertAllClose(ans, onp.ones((3, 1), onp.float32))
|
||||
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
@ -1722,9 +1720,9 @@ class LazyConstantTest(jtu.JaxTestCase):
|
||||
jit_result = api.jit(lambda x: lax.add(x, make_const()))(zero)
|
||||
|
||||
# ensure they're all the same
|
||||
self.assertAllClose(asarray_result, expected, check_dtypes=True)
|
||||
self.assertAllClose(argument_result, expected, check_dtypes=True)
|
||||
self.assertAllClose(jit_result, expected, check_dtypes=True)
|
||||
self.assertAllClose(asarray_result, expected)
|
||||
self.assertAllClose(argument_result, expected)
|
||||
self.assertAllClose(jit_result, expected)
|
||||
|
||||
# ensure repr doesn't crash
|
||||
repr(make_const())
|
||||
@ -2748,11 +2746,11 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
x = 3.14
|
||||
ans = api.grad(f)(x)
|
||||
expected = api.grad(f2)(x, x)
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
ans = api.grad(api.grad(f))(x)
|
||||
expected = api.grad(api.grad(f2))(x, x)
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
ans = api.grad(lambda x: lax.stop_gradient({'foo':x})['foo'])(3.)
|
||||
expected = onp.array(0.0)
|
||||
@ -2784,8 +2782,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
# N.B.: intentionally written as 1/x, not x ** -1 or reciprocal(x)
|
||||
return 1 / x
|
||||
grad_fn = jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(inv))))))
|
||||
self.assertAllClose(onp.float32(0.0439453125), grad_fn(onp.float32(4.)),
|
||||
check_dtypes=True)
|
||||
self.assertAllClose(onp.float32(0.0439453125), grad_fn(onp.float32(4.)))
|
||||
|
||||
def all_bdims(*shapes):
|
||||
bdims = (itertools.chain([cast(Optional[int], None)],
|
||||
@ -2819,7 +2816,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
args_slice = args_slicer(args, bdims)
|
||||
ans = api.vmap(op, bdims)(*args)
|
||||
expected = onp.stack([op(*args_slice(i)) for i in range(bdim_size)])
|
||||
self.assertAllClose(ans, expected, check_dtypes=True, rtol=rtol, atol=atol)
|
||||
self.assertAllClose(ans, expected, rtol=rtol, atol=atol)
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list(
|
||||
|
@ -71,8 +71,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
self.skipTest("Unimplemented case for complex Cholesky decomposition.")
|
||||
|
||||
self._CheckAgainstNumpy(np.linalg.cholesky, jnp.linalg.cholesky, args_maker,
|
||||
check_dtypes=True, tol=1e-3)
|
||||
self._CompileAndCheck(jnp.linalg.cholesky, args_maker, check_dtypes=True)
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(jnp.linalg.cholesky, args_maker)
|
||||
|
||||
if jnp.finfo(dtype).bits == 64:
|
||||
jtu.check_grads(jnp.linalg.cholesky, args_maker(), order=2)
|
||||
@ -96,14 +96,13 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
_skip_if_unsupported_type(dtype)
|
||||
args_maker = lambda: [rng((n, n), dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(np.linalg.det, jnp.linalg.det, args_maker,
|
||||
check_dtypes=True, tol=1e-3)
|
||||
self._CompileAndCheck(jnp.linalg.det, args_maker, check_dtypes=True,
|
||||
self._CheckAgainstNumpy(np.linalg.det, jnp.linalg.det, args_maker, tol=1e-3)
|
||||
self._CompileAndCheck(jnp.linalg.det, args_maker,
|
||||
rtol={np.float64: 1e-13, np.complex128: 1e-13})
|
||||
|
||||
def testDetOfSingularMatrix(self):
|
||||
x = jnp.array([[-1., 3./2], [2./3, -1.]], dtype=np.float32)
|
||||
self.assertAllClose(np.float32(0), jsp.linalg.det(x), check_dtypes=True)
|
||||
self.assertAllClose(np.float32(0), jsp.linalg.det(x))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
@ -177,10 +176,9 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
|
||||
self._CheckAgainstNumpy(np.linalg.tensorsolve,
|
||||
jnp.linalg.tensorsolve, args_maker,
|
||||
check_dtypes=True,
|
||||
tol={np.float32: 1e-2, np.float64: 1e-3})
|
||||
self._CompileAndCheck(jnp.linalg.tensorsolve,
|
||||
args_maker, check_dtypes=True,
|
||||
args_maker,
|
||||
rtol={np.float64: 1e-13})
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -198,8 +196,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(np.linalg.slogdet, jnp.linalg.slogdet, args_maker,
|
||||
check_dtypes=True, tol=1e-3)
|
||||
self._CompileAndCheck(jnp.linalg.slogdet, args_maker, check_dtypes=True)
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(jnp.linalg.slogdet, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
@ -221,7 +219,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
mat = jnp.array([np.diag(np.ones([5], dtype=np.float32))*(-.01)] * 2)
|
||||
args_maker = lambda: [mat]
|
||||
self._CheckAgainstNumpy(np.linalg.slogdet, jnp.linalg.slogdet, args_maker,
|
||||
check_dtypes=True, tol=1e-3)
|
||||
tol=1e-3)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}".format(
|
||||
@ -249,7 +247,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
self.assertTrue(np.all(norm(np.matmul(a, v) - w[..., None, :] * v) < 100))
|
||||
|
||||
self._CompileAndCheck(partial(jnp.linalg.eig), args_maker,
|
||||
check_dtypes=True, rtol=1e-3)
|
||||
rtol=1e-3)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}".format(
|
||||
@ -269,7 +267,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
a, = args_maker()
|
||||
w1, _ = jnp.linalg.eig(a)
|
||||
w2 = jnp.linalg.eigvals(a)
|
||||
self.assertAllClose(w1, w2, check_dtypes=True)
|
||||
self.assertAllClose(w1, w2)
|
||||
|
||||
@jtu.skip_on_devices("gpu", "tpu")
|
||||
def testEigvalsInf(self):
|
||||
@ -328,7 +326,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
self.assertTrue(norm(np.matmul(a, v) - w * v) < tol)
|
||||
|
||||
self._CompileAndCheck(partial(jnp.linalg.eigh, UPLO=uplo), args_maker,
|
||||
check_dtypes=True, rtol=1e-3)
|
||||
rtol=1e-3)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}".format(
|
||||
@ -349,7 +347,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
a = (a + np.conj(a.T)) / 2
|
||||
return [a]
|
||||
self._CheckAgainstNumpy(np.linalg.eigvalsh, jnp.linalg.eigvalsh, args_maker,
|
||||
check_dtypes=True, tol=1e-3)
|
||||
tol=1e-3)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
@ -478,9 +476,9 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
np_fn = partial(np.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
|
||||
np_fn = partial(jnp.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
|
||||
self._CheckAgainstNumpy(np_fn, np_fn, args_maker,
|
||||
check_dtypes=False, tol=1e-3)
|
||||
self._CompileAndCheck(np_fn, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(np_fn, np_fn, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(np_fn, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_n={}_full_matrices={}_compute_uv={}".format(
|
||||
@ -533,7 +531,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
self.assertTrue(np.allclose(np.linalg.svd(a, compute_uv=False), np.asarray(out), atol=1e-4, rtol=1e-4))
|
||||
|
||||
self._CompileAndCheck(partial(jnp.linalg.svd, full_matrices=full_matrices, compute_uv=compute_uv),
|
||||
args_maker, check_dtypes=True)
|
||||
args_maker)
|
||||
if not (compute_uv and full_matrices):
|
||||
svd = partial(jnp.linalg.svd, full_matrices=full_matrices,
|
||||
compute_uv=compute_uv)
|
||||
@ -698,8 +696,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(np.linalg.solve, jnp.linalg.solve, args_maker,
|
||||
check_dtypes=True, tol=1e-3)
|
||||
self._CompileAndCheck(jnp.linalg.solve, args_maker, check_dtypes=True)
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(jnp.linalg.solve, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
@ -726,8 +724,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
return [a]
|
||||
|
||||
self._CheckAgainstNumpy(np.linalg.inv, jnp.linalg.inv, args_maker,
|
||||
check_dtypes=True, tol=1e-3)
|
||||
self._CompileAndCheck(jnp.linalg.inv, args_maker, check_dtypes=True)
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(jnp.linalg.inv, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
@ -743,8 +741,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(np.linalg.pinv, jnp.linalg.pinv, args_maker,
|
||||
check_dtypes=True, tol=1e-2)
|
||||
self._CompileAndCheck(jnp.linalg.pinv, args_maker, check_dtypes=True)
|
||||
tol=1e-2)
|
||||
self._CompileAndCheck(jnp.linalg.pinv, args_maker)
|
||||
# TODO(phawkins): 1e-1 seems like a very loose tolerance.
|
||||
jtu.check_grads(jnp.linalg.pinv, args_maker(), 2, rtol=1e-1, atol=2e-1)
|
||||
|
||||
@ -754,15 +752,13 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
a = jnp.array([[0., 0.],[-p, 1.]], jnp.float32) * 1 / (1 + p**2)
|
||||
return jnp.linalg.pinv(a)
|
||||
j = jax.jacobian(f)(jnp.float32(2.))
|
||||
self.assertAllClose(jnp.array([[0., -1.], [ 0., 0.]], jnp.float32), j,
|
||||
check_dtypes=True)
|
||||
self.assertAllClose(jnp.array([[0., -1.], [ 0., 0.]], jnp.float32), j)
|
||||
|
||||
expected = jnp.array([[[[-1., 0.], [ 0., 0.]], [[0., -1.], [0., 0.]]],
|
||||
[[[0., 0.], [-1., 0.]], [[0., 0.], [0., -1.]]]],
|
||||
dtype=jnp.float32)
|
||||
self.assertAllClose(
|
||||
expected, jax.jacobian(jnp.linalg.pinv)(jnp.eye(2, dtype=jnp.float32)),
|
||||
check_dtypes=True)
|
||||
expected, jax.jacobian(jnp.linalg.pinv)(jnp.eye(2, dtype=jnp.float32)))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_n={}".format(
|
||||
@ -781,9 +777,9 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
tol = 1e-1 if jtu.device_under_test() == "tpu" else 1e-3
|
||||
self._CheckAgainstNumpy(partial(np.linalg.matrix_power, n=n),
|
||||
partial(jnp.linalg.matrix_power, n=n),
|
||||
args_maker, check_dtypes=True, tol=tol)
|
||||
args_maker, tol=tol)
|
||||
self._CompileAndCheck(partial(jnp.linalg.matrix_power, n=n), args_maker,
|
||||
check_dtypes=True, rtol=1e-3)
|
||||
rtol=1e-3)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}".format(
|
||||
@ -825,9 +821,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
tol = {np.float32: 1e-4, np.float64: 1e-10,
|
||||
np.complex64: 1e-4, np.complex128: 1e-10}
|
||||
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True,
|
||||
tol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True,
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker,
|
||||
atol=tol, rtol=tol)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -865,7 +860,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
return [lhs, rng(rhs_shape, dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(onp_fun, jnp_fun_numpy_resid, args_maker, check_dtypes=False, tol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, atol=tol, rtol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol)
|
||||
|
||||
# Disabled because grad is flaky for low-rank inputs.
|
||||
# TODO:
|
||||
@ -880,7 +875,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
|
||||
grad_test_jc = jit(grad(jit(test)))
|
||||
xc = np.eye(3, dtype=np.complex)
|
||||
self.assertAllClose(xc, grad_test_jc(xc), check_dtypes=True)
|
||||
self.assertAllClose(xc, grad_test_jc(xc))
|
||||
|
||||
@jtu.skip_on_flag("jax_skip_slow_tests", True)
|
||||
def testIssue1151(self):
|
||||
@ -888,8 +883,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
A = jnp.array(rng.randn(100, 3, 3), dtype=jnp.float32)
|
||||
b = jnp.array(rng.randn(100, 3), dtype=jnp.float32)
|
||||
x = jnp.linalg.solve(A, b)
|
||||
self.assertAllClose(vmap(jnp.dot)(A, x), b, atol=2e-3, rtol=1e-2,
|
||||
check_dtypes=True)
|
||||
self.assertAllClose(vmap(jnp.dot)(A, x), b, atol=2e-3, rtol=1e-2)
|
||||
jac0 = jax.jacobian(jnp.linalg.solve, argnums=0)(A, b)
|
||||
jac1 = jax.jacobian(jnp.linalg.solve, argnums=1)(A, b)
|
||||
jac0 = jax.jacobian(jnp.linalg.solve, argnums=0)(A[0], b[0])
|
||||
@ -926,8 +920,8 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
def testBlockDiag(self, args):
|
||||
args_maker = lambda: args
|
||||
self._CheckAgainstNumpy(osp.linalg.block_diag, jsp.linalg.block_diag,
|
||||
args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jsp.linalg.block_diag, args_maker, check_dtypes=True)
|
||||
args_maker)
|
||||
self._CompileAndCheck(jsp.linalg.block_diag, args_maker)
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -943,15 +937,15 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
x, = args_maker()
|
||||
p, l, u = jsp.linalg.lu(x)
|
||||
self.assertAllClose(x, np.matmul(p, np.matmul(l, u)), check_dtypes=True,
|
||||
self.assertAllClose(x, np.matmul(p, np.matmul(l, u)),
|
||||
rtol={np.float32: 1e-3, np.float64: 1e-12,
|
||||
np.complex64: 1e-3, np.complex128: 1e-12})
|
||||
self._CompileAndCheck(jsp.linalg.lu, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jsp.linalg.lu, args_maker)
|
||||
|
||||
def testLuOfSingularMatrix(self):
|
||||
x = jnp.array([[-1., 3./2], [2./3, -1.]], dtype=np.float32)
|
||||
p, l, u = jsp.linalg.lu(x)
|
||||
self.assertAllClose(x, np.matmul(p, np.matmul(l, u)), check_dtypes=True)
|
||||
self.assertAllClose(x, np.matmul(p, np.matmul(l, u)))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
@ -986,9 +980,9 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
us = np.stack([out[2] for out in expected])
|
||||
|
||||
actual_ps, actual_ls, actual_us = vmap(jsp.linalg.lu)(jnp.stack(args))
|
||||
self.assertAllClose(ps, actual_ps, check_dtypes=True)
|
||||
self.assertAllClose(ls, actual_ls, check_dtypes=True)
|
||||
self.assertAllClose(us, actual_us, check_dtypes=True)
|
||||
self.assertAllClose(ps, actual_ps)
|
||||
self.assertAllClose(ls, actual_ls)
|
||||
self.assertAllClose(us, actual_us)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
@ -1008,9 +1002,9 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
u = np.triu(lu)
|
||||
for i in range(n):
|
||||
x[[i, piv[i]],] = x[[piv[i], i],]
|
||||
self.assertAllClose(x, np.matmul(l, u), check_dtypes=True, rtol=1e-3,
|
||||
self.assertAllClose(x, np.matmul(l, u), rtol=1e-3,
|
||||
atol=1e-3)
|
||||
self._CompileAndCheck(jsp.linalg.lu_factor, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jsp.linalg.lu_factor, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
@ -1040,9 +1034,8 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
lu, piv = osp.linalg.lu_factor(a)
|
||||
return [lu, piv, rng(rhs_shape, dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker,
|
||||
check_dtypes=True, tol=1e-3)
|
||||
self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=1e-3)
|
||||
self._CompileAndCheck(jsp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
@ -1081,9 +1074,8 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
a = np.tril(a) if lower else np.triu(a)
|
||||
return [a, rng(rhs_shape, dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker,
|
||||
check_dtypes=True, tol=1e-3)
|
||||
self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=1e-3)
|
||||
self._CompileAndCheck(jsp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
@ -1132,7 +1124,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
l if lower else T(l), b, trans=1 if transpose_a else 0, lower=lower,
|
||||
unit_diagonal=unit_diagonal)
|
||||
|
||||
self.assertAllClose(np_ans, ans, check_dtypes=True,
|
||||
self.assertAllClose(np_ans, ans,
|
||||
rtol={np.float32: 1e-4, np.float64: 1e-11})
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -1228,15 +1220,13 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
|
||||
osp_fun = lambda a: osp.linalg.expm(a)
|
||||
jsp_fun = lambda a: jsp.linalg.expm(a)
|
||||
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker,
|
||||
check_dtypes=True)
|
||||
self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker)
|
||||
self._CompileAndCheck(jsp_fun, args_maker)
|
||||
|
||||
args_maker_triu = lambda: [np.triu(rng((n, n), dtype))]
|
||||
jsp_fun_triu = lambda a: jsp.linalg.expm(a,upper_triangular=True)
|
||||
self._CheckAgainstNumpy(osp_fun, jsp_fun_triu, args_maker_triu,
|
||||
check_dtypes=True)
|
||||
self._CompileAndCheck(jsp_fun_triu, args_maker_triu, check_dtypes=True)
|
||||
jsp_fun_triu = lambda a: jsp.linalg.expm(a, upper_triangular=True)
|
||||
self._CheckAgainstNumpy(osp_fun, jsp_fun_triu, args_maker_triu)
|
||||
self._CompileAndCheck(jsp_fun_triu, args_maker_triu)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
@ -1249,9 +1239,8 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
args_maker_zeros = lambda: [np.zeros((n, n), dtype)]
|
||||
osp_fun = lambda a: osp.linalg.expm(a)
|
||||
jsp_fun = lambda a: jsp.linalg.expm(a)
|
||||
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker_zeros,
|
||||
check_dtypes=True)
|
||||
self._CompileAndCheck(jsp_fun, args_maker_zeros, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker_zeros)
|
||||
self._CompileAndCheck(jsp_fun, args_maker_zeros)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_lhs={}_rhs={}_lower={}".format(
|
||||
@ -1280,7 +1269,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
U = np.triu(rng(lhs_shape, dtype))
|
||||
return [(U, lower), b]
|
||||
self._CheckAgainstNumpy(osp.linalg.cho_solve, jsp.linalg.cho_solve,
|
||||
args_maker, check_dtypes=True, tol=1e-3)
|
||||
args_maker, tol=1e-3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -35,7 +35,7 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
with loops.Scope() as s:
|
||||
s.x = r + 1
|
||||
return s.x
|
||||
self.assertAllClose(4.0, f_op(3.), check_dtypes=True)
|
||||
self.assertAllClose(4.0, f_op(3.))
|
||||
|
||||
def test_loop_empty(self):
|
||||
def f_op(r):
|
||||
@ -44,7 +44,7 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
pass
|
||||
return r
|
||||
|
||||
self.assertAllClose(3.0, f_op(3.), check_dtypes=True)
|
||||
self.assertAllClose(3.0, f_op(3.))
|
||||
|
||||
def test_loop_1(self):
|
||||
"""One loop with one state var, with transforms."""
|
||||
@ -56,14 +56,14 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
return s.out
|
||||
def f_expected(inc):
|
||||
return 10 + 5 * inc
|
||||
self.assertAllClose(f_expected(2.), f_op(2.), check_dtypes=True)
|
||||
self.assertAllClose(f_expected(2.), api.jit(f_op)(2.), check_dtypes=True)
|
||||
self.assertAllClose(5., api.grad(f_op)(2.), check_dtypes=True)
|
||||
self.assertAllClose(5., api.grad(f_op)(2.), check_dtypes=True)
|
||||
self.assertAllClose(f_expected(2.), f_op(2.))
|
||||
self.assertAllClose(f_expected(2.), api.jit(f_op)(2.))
|
||||
self.assertAllClose(5., api.grad(f_op)(2.))
|
||||
self.assertAllClose(5., api.grad(f_op)(2.))
|
||||
inc_batch = np.arange(5, dtype=jnp.float_)
|
||||
self.assertAllClose(jnp.array([f_expected(inc) for inc in inc_batch],
|
||||
dtype=jnp.float_),
|
||||
api.vmap(f_op)(inc_batch), check_dtypes=True)
|
||||
api.vmap(f_op)(inc_batch))
|
||||
|
||||
|
||||
def test_loop_2(self):
|
||||
@ -77,7 +77,7 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
s.out2 += 1.
|
||||
return (s.out1, s.out2)
|
||||
|
||||
self.assertAllClose((10. + 2. * 5, 20. + 1. * 5), f_op(2.), check_dtypes=True)
|
||||
self.assertAllClose((10. + 2. * 5, 20. + 1. * 5), f_op(2.))
|
||||
|
||||
|
||||
def test_add_vectors(self):
|
||||
@ -92,7 +92,7 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
|
||||
x = jnp.array([1., 2., 3.], dtype=jnp.float32)
|
||||
y = jnp.array([4., 5., 6.], dtype=jnp.float32)
|
||||
self.assertAllClose(jnp.add(x, y), add_vec(x, y), check_dtypes=True)
|
||||
self.assertAllClose(jnp.add(x, y), add_vec(x, y))
|
||||
|
||||
def test_matmul(self):
|
||||
def matmul(x, y):
|
||||
@ -109,7 +109,7 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
|
||||
x = jnp.array([[1., 2., 3.]], dtype=jnp.float32) # 1x3
|
||||
y = jnp.array([[4.], [5.], [6.]], dtype=jnp.float32) # 3x1
|
||||
self.assertAllClose(jnp.matmul(x, y), matmul(x, y), check_dtypes=True)
|
||||
self.assertAllClose(jnp.matmul(x, y), matmul(x, y))
|
||||
|
||||
def test_reuse_range(self):
|
||||
"""Ranges can be reused, as long as not nested in each other."""
|
||||
@ -136,7 +136,7 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
s.out += inc
|
||||
return s.out
|
||||
|
||||
self.assertAllClose(10. + 5 * (2. + 6 * 2.), f_op(2.), check_dtypes=True)
|
||||
self.assertAllClose(10. + 5 * (2. + 6 * 2.), f_op(2.))
|
||||
|
||||
def test_example_doc(self):
|
||||
"The example from the module docstring."
|
||||
@ -169,8 +169,8 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
s.arr = ops.index_update(s.arr, i, s.arr[i] + 1.)
|
||||
return s.arr
|
||||
|
||||
self.assertAllClose(f_expected(), f_op_jax(), check_dtypes=True)
|
||||
self.assertAllClose(f_expected(), f_op_loops(), check_dtypes=True)
|
||||
self.assertAllClose(f_expected(), f_op_jax())
|
||||
self.assertAllClose(f_expected(), f_op_loops())
|
||||
|
||||
def test_loop_mutable_used_but_not_changed(self):
|
||||
def f_op(inc):
|
||||
@ -184,7 +184,7 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
|
||||
return save_to_other_var
|
||||
|
||||
self.assertAllClose(10. + 5 * 2., f_op(2.), check_dtypes=True)
|
||||
self.assertAllClose(10. + 5 * 2., f_op(2.))
|
||||
|
||||
def test_range_locations(self):
|
||||
"""Ranges have locations."""
|
||||
@ -257,7 +257,7 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
pass
|
||||
return i
|
||||
|
||||
self.assertAllClose(4, f_op(4), check_dtypes=True)
|
||||
self.assertAllClose(4, f_op(4))
|
||||
|
||||
def test_error_new_state_in_loop(self):
|
||||
"""Error when creating new state in a loop."""
|
||||
@ -281,13 +281,13 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
s.out += inc
|
||||
return s.out
|
||||
|
||||
self.assertAllClose(16., f_op(0, 4, 4.), check_dtypes=True)
|
||||
self.assertAllClose(16., f_op(0, 4, 4.))
|
||||
# Ok to jit, as long as the start and end are static
|
||||
self.assertAllClose(16., api.jit(f_op, static_argnums=(0, 1))(0, 4, 4.), check_dtypes=True)
|
||||
self.assertAllClose(16., api.jit(f_op, static_argnums=(0, 1))(0, 4, 4.))
|
||||
with self.assertRaisesRegex(TypeError, "Abstract tracer value encountered where concrete value is expected"):
|
||||
self.assertAllClose(16., api.jit(f_op)(0, 4, 4.), check_dtypes=True)
|
||||
self.assertAllClose(16., api.jit(f_op)(0, 4, 4.))
|
||||
with self.assertRaisesRegex(TypeError, "Abstract tracer value encountered where concrete value is expected"):
|
||||
self.assertAllClose(16., api.vmap(f_op)(jnp.zeros(10), jnp.ones(10), jnp.array([4.] * 10)), check_dtypes=True)
|
||||
self.assertAllClose(16., api.vmap(f_op)(jnp.zeros(10), jnp.ones(10), jnp.array([4.] * 10)))
|
||||
|
||||
def test_cond(self):
|
||||
def f_op(inc):
|
||||
@ -297,8 +297,8 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
s.out += inc
|
||||
return s.out
|
||||
|
||||
self.assertAllClose(10. + 2., f_op(2.), check_dtypes=True)
|
||||
self.assertAllClose(10., f_op(-2.), check_dtypes=True)
|
||||
self.assertAllClose(10. + 2., f_op(2.))
|
||||
self.assertAllClose(10., f_op(-2.))
|
||||
|
||||
def test_cond_state(self):
|
||||
"""Conditionals predicated on scope fields."""
|
||||
@ -309,8 +309,8 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
s.out *= 2.
|
||||
return s.out
|
||||
|
||||
self.assertAllClose(2. * 2., f_op(2.), check_dtypes=True)
|
||||
self.assertAllClose(-2., f_op(-2.), check_dtypes=True)
|
||||
self.assertAllClose(2. * 2., f_op(2.))
|
||||
self.assertAllClose(-2., f_op(-2.))
|
||||
|
||||
def test_cond_nested(self):
|
||||
"""Nested conditionals."""
|
||||
@ -341,7 +341,7 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
return s.out
|
||||
|
||||
for init in [-1., 0., 9., 10.]:
|
||||
self.assertAllClose(f_expected(init), f_op(init), check_dtypes=True)
|
||||
self.assertAllClose(f_expected(init), f_op(init))
|
||||
|
||||
|
||||
def test_error_cond_using_index_var(self):
|
||||
@ -373,13 +373,13 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
out += 1.
|
||||
return out
|
||||
|
||||
self.assertAllClose(f_expected(2.), f_op(2.), check_dtypes=True)
|
||||
self.assertAllClose(f_expected(2.), api.jit(f_op)(2.), check_dtypes=True)
|
||||
self.assertAllClose(f_expected(1.), f_op(1.), check_dtypes=True)
|
||||
self.assertAllClose(f_expected(2.), f_op(2.))
|
||||
self.assertAllClose(f_expected(2.), api.jit(f_op)(2.))
|
||||
self.assertAllClose(f_expected(1.), f_op(1.))
|
||||
init_batch = np.array([1., 2., 3.], dtype=np.float32)
|
||||
self.assertAllClose(np.array([f_expected(init) for init in init_batch],
|
||||
dtype=np.float32),
|
||||
api.vmap(f_op)(init_batch), check_dtypes=True)
|
||||
api.vmap(f_op)(init_batch))
|
||||
|
||||
def test_error_while_cond_mutation(self):
|
||||
"""Disallow mutation in the while conditional."""
|
||||
|
@ -601,8 +601,8 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
|
||||
def test_slice_oob_indexing(self):
|
||||
# https://github.com/google/jax/issues/2245
|
||||
self.assertAllClose(jnp.ones(5), jnp.ones(5)[:10], check_dtypes=True)
|
||||
self.assertAllClose(jnp.ones(5), jnp.ones(5)[-10:], check_dtypes=True)
|
||||
self.assertAllClose(jnp.ones(5), jnp.ones(5)[:10])
|
||||
self.assertAllClose(jnp.ones(5), jnp.ones(5)[-10:])
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
@ -51,7 +51,7 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
y = npr.uniform(size=(10,10))
|
||||
z_host = np.matmul(x, y)
|
||||
z = fun(x, y)
|
||||
self.assertAllClose(z, z_host, check_dtypes=True, rtol=1e-2)
|
||||
self.assertAllClose(z, z_host, rtol=1e-2)
|
||||
correct_platform = backend if backend else jtu.device_under_test()
|
||||
self.assertEqual(z.device_buffer.platform(), correct_platform)
|
||||
|
||||
@ -73,7 +73,7 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
y = npr.uniform(size=(10,10))
|
||||
z_host = np.matmul(x, y) + np.ones_like(x)
|
||||
z = fun(x, y)
|
||||
self.assertAllClose(z, z_host, check_dtypes=True, rtol=1e-2)
|
||||
self.assertAllClose(z, z_host, rtol=1e-2)
|
||||
correct_platform = outer if outer else jtu.device_under_test()
|
||||
self.assertEqual(z.device_buffer.platform(), correct_platform)
|
||||
|
||||
|
@ -55,7 +55,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
||||
|
||||
def testSoftplusGradInf(self):
|
||||
self.assertAllClose(
|
||||
1., jax.grad(nn.softplus)(float('inf')), check_dtypes=True)
|
||||
1., jax.grad(nn.softplus)(float('inf')))
|
||||
|
||||
def testSoftplusGradNegInf(self):
|
||||
check_grads(nn.softplus, (-float('inf'),), order=1,
|
||||
@ -91,7 +91,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
||||
|
||||
def testGluValue(self):
|
||||
val = nn.glu(jnp.array([1.0, 0.0]))
|
||||
self.assertAllClose(val, jnp.array([0.5]), check_dtypes=True)
|
||||
self.assertAllClose(val, jnp.array([0.5]))
|
||||
|
||||
@parameterized.parameters(*itertools.product(
|
||||
(jnp.float32, jnp.bfloat16, jnp.float16),
|
||||
@ -120,33 +120,33 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
||||
expected = jnp.array([[1., 0., 0.],
|
||||
[0., 1., 0.],
|
||||
[0., 0., 1.]])
|
||||
self.assertAllClose(actual, expected, check_dtypes=True)
|
||||
self.assertAllClose(actual, expected)
|
||||
|
||||
actual = nn.one_hot(jnp.array([1, 2, 0]), 3)
|
||||
expected = jnp.array([[0., 1., 0.],
|
||||
[0., 0., 1.],
|
||||
[1., 0., 0.]])
|
||||
self.assertAllClose(actual, expected, check_dtypes=True)
|
||||
self.assertAllClose(actual, expected)
|
||||
|
||||
def testOneHotOutOfBound(self):
|
||||
actual = nn.one_hot(jnp.array([-1, 3]), 3)
|
||||
expected = jnp.array([[0., 0., 0.],
|
||||
[0., 0., 0.]])
|
||||
self.assertAllClose(actual, expected, check_dtypes=True)
|
||||
self.assertAllClose(actual, expected)
|
||||
|
||||
def testOneHotNonArrayInput(self):
|
||||
actual = nn.one_hot([0, 1, 2], 3)
|
||||
expected = jnp.array([[1., 0., 0.],
|
||||
[0., 1., 0.],
|
||||
[0., 0., 1.]])
|
||||
self.assertAllClose(actual, expected, check_dtypes=True)
|
||||
self.assertAllClose(actual, expected)
|
||||
|
||||
def testOneHotCustomDtype(self):
|
||||
actual = nn.one_hot(jnp.array([0, 1, 2]), 3, dtype=jnp.bool_)
|
||||
expected = jnp.array([[True, False, False],
|
||||
[False, True, False],
|
||||
[False, False, True]])
|
||||
self.assertAllClose(actual, expected, check_dtypes=True)
|
||||
self.assertAllClose(actual, expected)
|
||||
|
||||
InitializerRecord = collections.namedtuple(
|
||||
"InitializerRecord",
|
||||
|
@ -39,7 +39,7 @@ class OptimizerTests(jtu.JaxTestCase):
|
||||
def _CheckFuns(self, optimizer, loss, x0, *args):
|
||||
init_fun, update_fun, get_params = optimizer(*args)
|
||||
opt_state = init_fun(x0)
|
||||
self.assertAllClose(x0, get_params(opt_state), check_dtypes=True)
|
||||
self.assertAllClose(x0, get_params(opt_state))
|
||||
opt_state2 = update_fun(0, grad(loss)(x0), opt_state) # doesn't crash
|
||||
self.assertEqual(tree_util.tree_structure(opt_state),
|
||||
tree_util.tree_structure(opt_state2))
|
||||
@ -294,7 +294,7 @@ class OptimizerTests(jtu.JaxTestCase):
|
||||
|
||||
J1 = jacrev(loss, argnums=(0,))(initial_params)
|
||||
J2 = jacfwd(loss, argnums=(0,))(initial_params)
|
||||
self.assertAllClose(J1, J2, check_dtypes=True, rtol=1e-6)
|
||||
self.assertAllClose(J1, J2, rtol=1e-6)
|
||||
|
||||
def testUnpackPackRoundTrip(self):
|
||||
opt_init, _, _ = optimizers.momentum(0.1, mass=0.9)
|
||||
|
@ -87,7 +87,7 @@ class PapplyTest(jtu.JaxTestCase):
|
||||
t = np.ones((5, 3))
|
||||
ans = soft_pmap(*_papply(fun))(t)
|
||||
expected = fun(t)
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
def testLogSoftmax(self):
|
||||
raise SkipTest("test doesn't pass yet") # TODO(frostig)
|
||||
@ -113,7 +113,7 @@ class PapplyTest(jtu.JaxTestCase):
|
||||
|
||||
pfun, axis_name = _papply(jnp.add)
|
||||
ans = soft_pmap(pfun, axis_name)(x, x)
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
def testAddBroadcasting(self):
|
||||
raise SkipTest("test doesn't pass yet") # TODO(frostig)
|
||||
@ -126,7 +126,7 @@ class PapplyTest(jtu.JaxTestCase):
|
||||
|
||||
pfun, axis_name = _papply(fun)
|
||||
ans = soft_pmap(pfun, axis_name)(x)
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
def testMakeJaxprPapplyComposition(self):
|
||||
raise SkipTest( # TODO(mattjj)
|
||||
@ -142,7 +142,7 @@ class ParallelizeTest(jtu.JaxTestCase):
|
||||
def dedup(self, arr, expected_rank):
|
||||
if arr.ndim == expected_rank + 1:
|
||||
for i in range(arr.shape[0] - 1):
|
||||
self.assertAllClose(arr[i], arr[i + 1], check_dtypes=True)
|
||||
self.assertAllClose(arr[i], arr[i + 1])
|
||||
return arr[0]
|
||||
else:
|
||||
assert arr.ndim == expected_rank
|
||||
|
@ -217,7 +217,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
|
||||
f_expected = np.broadcast_to(x, mesh_shape)
|
||||
f_ans = f(x, y)
|
||||
self.assertAllClose(f_ans, f_expected, check_dtypes=True)
|
||||
self.assertAllClose(f_ans, f_expected)
|
||||
self.assertIsInstance(f_ans, pxla.ShardedDeviceArray)
|
||||
# the output is actually replicated (has the same values in each device buffer)
|
||||
# but out_axes is implicitly 0, so we shouldn't have replication in the
|
||||
@ -226,7 +226,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
|
||||
g_expected = np.broadcast_to(x - np.sum(y, 0, keepdims=True), shape)
|
||||
g_ans = g(x, y)
|
||||
self.assertAllClose(g_ans, g_expected, check_dtypes=True)
|
||||
self.assertAllClose(g_ans, g_expected)
|
||||
self.assertIsInstance(g_ans, pxla.ShardedDeviceArray)
|
||||
self.assertEqual(g_ans.sharding_spec.replication_factor, 1)
|
||||
|
||||
@ -302,7 +302,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
|
||||
ans = grad(lambda x: jnp.sum(splitjvp(x)))(x)
|
||||
expected = grad(fun)(x)
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
def testTwoArgsGrad(self):
|
||||
def f(x, y):
|
||||
@ -352,7 +352,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
|
||||
ans = grad(lambda x: jnp.sum(test_fun(x)))(x)
|
||||
expected = grad(lambda x: jnp.sum(baseline_fun(x)))(x)
|
||||
self.assertAllClose(ans, expected, check_dtypes=True, atol=1e-3)
|
||||
self.assertAllClose(ans, expected, atol=1e-3)
|
||||
|
||||
def testShardedDeviceArrays(self):
|
||||
f = lambda x: 2 * x
|
||||
@ -466,11 +466,11 @@ class PmapTest(jtu.JaxTestCase):
|
||||
expected_psum = np.concatenate([expected_psum_1, expected_psum_2], 1)
|
||||
expected = x - expected_psum
|
||||
ans = f1(x)
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
expected = x - expected_psum + 1.
|
||||
ans = f2(x)
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
shape = (replicas // 2, 2, 4)
|
||||
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
|
||||
@ -482,7 +482,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
expected_psum = np.concatenate([expected_psum_1, expected_psum_2], 0)
|
||||
expected = x - expected_psum
|
||||
ans = f3(x)
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
def testAxisGroups(self):
|
||||
axis_env = xla.AxisEnv(8, ('i', 'j'), (4, 2))
|
||||
@ -568,7 +568,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
lambda x: lax.ppermute(x, "i", zip(range(num_devices), perm)), "i")
|
||||
result = f(jnp.arange(num_devices, dtype=jnp.float32))
|
||||
expected = jnp.asarray(perm, dtype=jnp.float32)
|
||||
self.assertAllClose(result, expected, check_dtypes=True)
|
||||
self.assertAllClose(result, expected)
|
||||
|
||||
@jtu.skip_on_devices("cpu", "gpu")
|
||||
def testRule30(self):
|
||||
@ -911,7 +911,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
arr = pxla.ShardedDeviceArray(aval, sharding_spec, bufs)
|
||||
|
||||
r = pmap(lambda x: x + 1)(arr)
|
||||
self.assertAllClose(r, arr + 1, check_dtypes=True)
|
||||
self.assertAllClose(r, arr + 1)
|
||||
self.assertEqual(len(r.device_buffers), 6)
|
||||
|
||||
@ignore_soft_pmap_warning()
|
||||
@ -1114,8 +1114,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
vals = list(range(500))
|
||||
ndevices = xla_bridge.device_count()
|
||||
self.assertAllClose(f(jnp.array([vals] * ndevices)),
|
||||
jnp.array([sum(vals)] * ndevices),
|
||||
check_dtypes=True)
|
||||
jnp.array([sum(vals)] * ndevices))
|
||||
|
||||
def testPostProcessMap(self):
|
||||
# code from https://github.com/google/jax/issues/2787
|
||||
@ -1221,7 +1220,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
|
||||
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
|
||||
expected = x - np.sum(x, 0)
|
||||
ans = f(x)
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
def testOneDevice(self):
|
||||
if xla_bridge.device_count() == 1:
|
||||
@ -1236,8 +1235,8 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
|
||||
r0 = f0(x)
|
||||
r1 = f1(x)
|
||||
expected = np.expand_dims(np.dot(x.squeeze(), x.squeeze().T), 0)
|
||||
self.assertAllClose(r0, expected, check_dtypes=True, atol=1e-6, rtol=1e-3)
|
||||
self.assertAllClose(r1, expected, check_dtypes=True, atol=1e-6, rtol=1e-3)
|
||||
self.assertAllClose(r0, expected, atol=1e-6, rtol=1e-3)
|
||||
self.assertAllClose(r1, expected, atol=1e-6, rtol=1e-3)
|
||||
|
||||
def testNoDevicesError(self):
|
||||
f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i', devices=[])
|
||||
@ -1303,7 +1302,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
|
||||
ndevices = xla_bridge.device_count()
|
||||
ans = foo(jnp.ones((ndevices, 1)))
|
||||
expected = np.ones((ndevices, 1), dtype=jnp.float_) * ndevices * 2
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
def testPmapInJit(self):
|
||||
@jit
|
||||
@ -1316,7 +1315,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
|
||||
ndevices = xla_bridge.device_count()
|
||||
ans = foo(jnp.ones((ndevices, 1)))
|
||||
expected = np.ones((ndevices, 1), dtype=jnp.float_) * ndevices
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
def testGradBasic(self):
|
||||
@partial(pmap, axis_name='i', devices=xla_bridge.devices())
|
||||
|
@ -144,15 +144,15 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
if jtu.device_under_test() != "tpu":
|
||||
bits8 = random._random_bits(key, 8, (3,))
|
||||
expected8 = np.array([216, 115, 43], dtype=np.uint8)
|
||||
self.assertArraysEqual(bits8, expected8, check_dtypes=True)
|
||||
self.assertArraysEqual(bits8, expected8)
|
||||
|
||||
bits16 = random._random_bits(key, 16, (3,))
|
||||
expected16 = np.array([41682, 1300, 55017], dtype=np.uint16)
|
||||
self.assertArraysEqual(bits16, expected16, check_dtypes=True)
|
||||
self.assertArraysEqual(bits16, expected16)
|
||||
|
||||
bits32 = random._random_bits(key, 32, (3,))
|
||||
expected32 = np.array([56197195, 4200222568, 961309823], dtype=np.uint32)
|
||||
self.assertArraysEqual(bits32, expected32, check_dtypes=True)
|
||||
self.assertArraysEqual(bits32, expected32)
|
||||
|
||||
bits64 = random._random_bits(key, 64, (3,))
|
||||
if FLAGS.jax_enable_x64:
|
||||
@ -160,7 +160,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
7882654074788531506], dtype=np.uint64)
|
||||
else:
|
||||
expected64 = np.array([676898860, 3164047411, 4010691890], dtype=np.uint32)
|
||||
self.assertArraysEqual(bits64, expected64, check_dtypes=True)
|
||||
self.assertArraysEqual(bits64, expected64)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype)}
|
||||
@ -227,7 +227,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
with self.assertWarns(FutureWarning):
|
||||
perm2 = crand(key)
|
||||
|
||||
self.assertAllClose(perm1, perm2, check_dtypes=True)
|
||||
self.assertAllClose(perm1, perm2)
|
||||
self.assertFalse(np.all(perm1 == x)) # seems unlikely!
|
||||
self.assertAllClose(np.sort(perm1), x, check_dtypes=False)
|
||||
|
||||
@ -245,12 +245,11 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
perm1 = rand(key)
|
||||
perm2 = crand(key)
|
||||
|
||||
self.assertAllClose(perm1, perm2, check_dtypes=True)
|
||||
self.assertAllClose(perm1, perm2)
|
||||
self.assertFalse(np.all(perm1 == x)) # seems unlikely!
|
||||
self.assertAllClose(np.sort(perm1.ravel()), x.ravel(), check_dtypes=False)
|
||||
self.assertArraysAllClose(
|
||||
x, jnp.arange(np.prod(shape)).reshape(shape).astype(dtype),
|
||||
check_dtypes=True)
|
||||
x, jnp.arange(np.prod(shape)).reshape(shape).astype(dtype))
|
||||
|
||||
def testPermutationInteger(self):
|
||||
key = random.PRNGKey(0)
|
||||
@ -261,7 +260,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
perm1 = rand(key)
|
||||
perm2 = crand(key)
|
||||
|
||||
self.assertAllClose(perm1, perm2, check_dtypes=True)
|
||||
self.assertAllClose(perm1, perm2)
|
||||
self.assertEqual(perm1.dtype, perm2.dtype)
|
||||
self.assertFalse(np.all(perm1 == np.arange(100))) # seems unlikely!
|
||||
self.assertAllClose(np.sort(perm1), np.arange(100), check_dtypes=False)
|
||||
@ -380,7 +379,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
compiled_samples = crand(key, alpha)
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self.assertAllClose(samples.sum(-1), np.ones(10000, dtype=dtype), check_dtypes=True)
|
||||
self.assertAllClose(samples.sum(-1), np.ones(10000, dtype=dtype))
|
||||
alpha_sum = sum(alpha)
|
||||
for i, a in enumerate(alpha):
|
||||
self._CheckKolmogorovSmirnovCDF(samples[..., i], scipy.stats.beta(a, alpha_sum - a).cdf)
|
||||
@ -435,7 +434,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
pdf = scipy.stats.gamma.pdf(z, alpha)
|
||||
expected_grad = -cdf_dot / pdf
|
||||
|
||||
self.assertAllClose(actual_grad, expected_grad, check_dtypes=True,
|
||||
self.assertAllClose(actual_grad, expected_grad,
|
||||
rtol=2e-2 if jtu.device_under_test() == "tpu" else 5e-4)
|
||||
|
||||
def testGammaGradType(self):
|
||||
@ -670,28 +669,24 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
random.randint(k, (3, 3), 0, 8),
|
||||
np.array([[7, 2, 6],
|
||||
[2, 1, 0],
|
||||
[6, 7, 7]], dtype='int64'),
|
||||
check_dtypes=True)
|
||||
[6, 7, 7]], dtype='int64'))
|
||||
else:
|
||||
self.assertAllClose(
|
||||
random.randint(k, (3, 3), 0, 8),
|
||||
np.array([[2, 1, 3],
|
||||
[6, 1, 5],
|
||||
[6, 3, 4]], dtype='int32'),
|
||||
check_dtypes=True)
|
||||
[6, 3, 4]], dtype='int32'))
|
||||
|
||||
self.assertAllClose(
|
||||
random.split(k, 4),
|
||||
np.array([[2285895361, 1501764800],
|
||||
[1518642379, 4090693311],
|
||||
[ 433833334, 4221794875],
|
||||
[ 839183663, 3740430601]], dtype='uint32'),
|
||||
check_dtypes=True)
|
||||
[ 839183663, 3740430601]], dtype='uint32'))
|
||||
|
||||
self.assertAllClose(
|
||||
random.fold_in(k, 4),
|
||||
np.array([2285895361, 433833334], dtype='uint32'),
|
||||
check_dtypes=True)
|
||||
np.array([2285895361, 433833334], dtype='uint32'))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -103,11 +103,9 @@ class NdimageTest(jtu.JaxTestCase):
|
||||
if dtype in float_dtypes:
|
||||
epsilon = max([dtypes.finfo(dtypes.canonicalize_dtype(d)).eps
|
||||
for d in [dtype, coords_dtype]])
|
||||
self._CheckAgainstNumpy(lsp_op, osp_op, args_maker, tol=100*epsilon,
|
||||
check_dtypes=True)
|
||||
self._CheckAgainstNumpy(lsp_op, osp_op, args_maker, tol=100*epsilon)
|
||||
else:
|
||||
self._CheckAgainstNumpy(lsp_op, osp_op, args_maker, tol=0,
|
||||
check_dtypes=True)
|
||||
self._CheckAgainstNumpy(lsp_op, osp_op, args_maker, tol=0)
|
||||
|
||||
def testMapCoordinatesErrors(self):
|
||||
x = onp.arange(5.0)
|
||||
@ -137,7 +135,7 @@ class NdimageTest(jtu.JaxTestCase):
|
||||
|
||||
lsp_op = lambda x, c: lsp_ndimage.map_coordinates(x, c, order=order)
|
||||
osp_op = lambda x, c: osp_ndimage.map_coordinates(x, c, order=order)
|
||||
self._CheckAgainstNumpy(lsp_op, osp_op, args_maker, check_dtypes=True)
|
||||
self._CheckAgainstNumpy(lsp_op, osp_op, args_maker)
|
||||
|
||||
def testContinuousGradients(self):
|
||||
# regression test for https://github.com/google/jax/issues/3024
|
||||
|
@ -64,7 +64,7 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
|
||||
jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
|
||||
tol = {onp.float16: 1e-2, onp.float32: 1e-2, onp.float64: 1e-8}
|
||||
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False, tol=tol)
|
||||
self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jsp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "op={}_xshape=[{}]_yshape=[{}]_mode={}".format(
|
||||
@ -87,7 +87,7 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
|
||||
jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
|
||||
tol = {onp.float16: 1e-2, onp.float32: 1e-2, onp.float64: 1e-14}
|
||||
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False, tol=tol)
|
||||
self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jsp_fun, args_maker)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -61,8 +61,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True,
|
||||
rtol={onp.float64: 1e-14})
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol={onp.float64: 1e-14})
|
||||
|
||||
@genNamedParametersNArgs(3, jtu.rand_default)
|
||||
def testPoissonPmf(self, rng_factory, shapes, dtypes):
|
||||
@ -80,7 +79,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3, jtu.rand_default)
|
||||
def testBernoulliLogPmf(self, rng_factory, shapes, dtypes):
|
||||
@ -97,7 +96,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3, jtu.rand_default)
|
||||
def testGeomLogPmf(self, rng_factory, shapes, dtypes):
|
||||
@ -114,7 +113,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(5, jtu.rand_positive)
|
||||
def testBetaLogPdf(self, rng_factory, shapes, dtypes):
|
||||
@ -128,7 +127,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True,
|
||||
self._CompileAndCheck(lax_fun, args_maker,
|
||||
rtol={onp.float32: 2e-3, onp.float64: 1e-4})
|
||||
|
||||
@genNamedParametersNArgs(3, jtu.rand_default)
|
||||
@ -145,7 +144,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(2, jtu.rand_positive)
|
||||
def testDirichletLogPdf(self, rng_factory, shapes, dtypes):
|
||||
@ -162,7 +161,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3, jtu.rand_positive)
|
||||
def testExponLogPdf(self, rng_factory, shapes, dtypes):
|
||||
@ -176,7 +175,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(4, jtu.rand_positive)
|
||||
def testGammaLogPdf(self, rng_factory, shapes, dtypes):
|
||||
@ -190,7 +189,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3, jtu.rand_positive)
|
||||
def testLaplaceLogPdf(self, rng_factory, shapes, dtypes):
|
||||
@ -206,7 +205,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3, jtu.rand_default)
|
||||
def testLaplaceCdf(self, rng_factory, shapes, dtypes):
|
||||
@ -222,7 +221,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol={onp.float32: 1e-5, onp.float64: 1e-6})
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(1, jtu.rand_default)
|
||||
def testLogisticCdf(self, rng_factory, shapes, dtypes):
|
||||
@ -235,7 +234,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-6)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(1, jtu.rand_default)
|
||||
def testLogisticLogpdf(self, rng_factory, shapes, dtypes):
|
||||
@ -248,7 +247,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(1, jtu.rand_default)
|
||||
def testLogisticPpf(self, rng_factory, shapes, dtypes):
|
||||
@ -261,7 +260,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(1, jtu.rand_default)
|
||||
def testLogisticSf(self, rng_factory, shapes, dtypes):
|
||||
@ -274,7 +273,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-6)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3, jtu.rand_default)
|
||||
def testNormLogPdf(self, rng_factory, shapes, dtypes):
|
||||
@ -290,7 +289,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
|
||||
@genNamedParametersNArgs(3, jtu.rand_default)
|
||||
@ -307,7 +306,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
|
||||
@genNamedParametersNArgs(3, jtu.rand_default)
|
||||
@ -324,7 +323,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-6)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
|
||||
@genNamedParametersNArgs(3, jtu.rand_default)
|
||||
@ -341,9 +340,8 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
|
||||
return [q, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True, rtol=3e-4)
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
|
||||
|
||||
|
||||
@genNamedParametersNArgs(4, jtu.rand_positive)
|
||||
@ -358,7 +356,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
|
||||
@genNamedParametersNArgs(4, jtu.rand_default)
|
||||
@ -375,7 +373,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
|
||||
@genNamedParametersNArgs(3, jtu.rand_default)
|
||||
@ -390,7 +388,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
def testIssue972(self):
|
||||
self.assertAllClose(
|
||||
@ -455,9 +453,9 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
self._CheckAgainstNumpy(osp_stats.multivariate_normal.logpdf,
|
||||
lsp_stats.multivariate_normal.logpdf,
|
||||
args_maker, check_dtypes=True, tol=1e-3)
|
||||
args_maker, tol=1e-3)
|
||||
self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf, args_maker,
|
||||
check_dtypes=True, rtol=1e-4, atol=1e-4)
|
||||
rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -127,18 +127,18 @@ class VectorizeTest(jtu.JaxTestCase):
|
||||
b, a = center(jnp.arange(3))
|
||||
self.assertEqual(a.shape, (3,))
|
||||
self.assertEqual(b.shape, ())
|
||||
self.assertAllClose(1.0, b, False)
|
||||
self.assertAllClose(1.0, b, check_dtypes=False)
|
||||
|
||||
X = jnp.arange(12).reshape((3, 4))
|
||||
b, a = center(X, axis=1)
|
||||
self.assertEqual(a.shape, (3, 4))
|
||||
self.assertEqual(b.shape, (3,))
|
||||
self.assertAllClose(jnp.mean(X, axis=1), b, True)
|
||||
self.assertAllClose(jnp.mean(X, axis=1), b)
|
||||
|
||||
b, a = center(X, axis=0)
|
||||
self.assertEqual(a.shape, (3, 4))
|
||||
self.assertEqual(b.shape, (4,))
|
||||
self.assertAllClose(jnp.mean(X, axis=0), b, True)
|
||||
self.assertAllClose(jnp.mean(X, axis=0), b)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user