Make pytest run over JAX tests warning clean, and error on warnings. (#2674)

* Make pytest run over JAX tests warning clean, and error on warnings.

Remove the global warning suppression in travis.yml. Instead, add a pytest.ini that converts warnings to errors, except for a whitelist of known-benign warnings.
Either fix or locally suppress warnings in tests.

Also fix crashes on Mac related to a preexisting linear algebra bug.

* Fix some type errors in the FFT transpose rules revealed by the convert_element_type transpose rule change.
Peter Hawkins 2020-04-12 15:35:35 -04:00 committed by GitHub
parent 453dc5f085
commit 2dc81fb40c
17 changed files with 147 additions and 76 deletions


@@ -47,5 +47,5 @@ script:
echo "===== Checking with mypy ====" &&
time mypy --config-file=mypy.ini jax ;
else
pytest tests examples -W ignore ;
pytest tests examples ;
fi


@@ -372,8 +372,6 @@ def convert_element_type(operand: Array, new_dtype: DType) -> Array:
not dtypes.issubdtype(new_dtype, onp.complexfloating)):
msg = "Casting complex values to real discards the imaginary part"
warnings.warn(msg, onp.ComplexWarning, stacklevel=2)
operand = real(operand)
old_dtype = _dtype(operand)
return convert_element_type_p.bind(
operand, new_dtype=new_dtype, old_dtype=old_dtype)
@@ -2109,15 +2107,21 @@ def _convert_element_type_dtype_rule(operand, *, new_dtype, old_dtype):
return new_dtype
def _convert_element_type_translation_rule(c, operand, *, new_dtype, old_dtype):
if (dtypes.issubdtype(old_dtype, onp.complexfloating) and
not dtypes.issubdtype(new_dtype, onp.complexfloating)):
operand = c.Real(operand)
new_etype = xla_client.dtype_to_etype(new_dtype)
return c.ConvertElementType(operand, new_element_type=new_etype)
def _convert_element_type_transpose_rule(t, *, new_dtype, old_dtype):
assert t.dtype == new_dtype, (t.dtype, new_dtype)
return [convert_element_type_p.bind(t, new_dtype=old_dtype,
old_dtype=new_dtype)]
convert_element_type_p = standard_primitive(
_convert_element_type_shape_rule, _convert_element_type_dtype_rule,
'convert_element_type', _convert_element_type_translation_rule)
ad.deflinear(
convert_element_type_p,
lambda t, new_dtype, old_dtype: [convert_element_type(t, old_dtype)])
ad.deflinear(convert_element_type_p, _convert_element_type_transpose_rule)
batching.defvectorized(convert_element_type_p)
masking.defvectorized(convert_element_type_p)
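The new transpose rule above casts the cotangent back to the operand's original dtype. A minimal sketch (not part of the patch) of the behavior it guarantees, via jax.vjp:

import jax
import jax.numpy as jnp
from jax import lax

# Differentiating through a complex->real cast: the cotangent coming out
# of the VJP must have the *old* (complex) dtype, per the rule above.
x = jnp.array([1.0 + 2.0j, 3.0 - 1.0j], dtype=jnp.complex64)
y, vjp_fn = jax.vjp(lambda v: lax.convert_element_type(v, jnp.float32), x)
ct, = vjp_fn(jnp.ones_like(y))
print(ct.dtype)  # complex64 (the cast itself raises a ComplexWarning,
                 # which the updated tests suppress explicitly)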


@@ -94,9 +94,10 @@ def _rfft_transpose(t, fft_lengths):
# asymptotic complexity and is also rather complicated), we rely on JAX to
# transpose a naive RFFT implementation.
dummy_shape = t.shape[:-len(fft_lengths)] + fft_lengths
dummy_primals = lax.full_like(t, 0.0, onp.float64, dummy_shape)
dummy_primals = lax.full_like(t, 0.0, _real_dtype(t.dtype), dummy_shape)
_, jvpfun = vjp(partial(_naive_rfft, fft_lengths=fft_lengths), dummy_primals)
result, = jvpfun(t)
assert result.dtype == _real_dtype(t.dtype), (result.dtype, t.dtype)
return result
def _irfft_transpose(t, fft_lengths):
@@ -107,14 +108,16 @@ def _irfft_transpose(t, fft_lengths):
x = fft(t, xla_client.FftType.RFFT, fft_lengths)
n = x.shape[-1]
is_odd = fft_lengths[-1] % 2
full = partial(lax.full_like, t, dtype=onp.float64)
full = partial(lax.full_like, t, dtype=t.dtype)
mask = lax.concatenate(
[full(1.0, shape=(1,)),
full(2.0, shape=(n - 2 + is_odd,)),
full(1.0, shape=(1 - is_odd,))],
dimension=0)
scale = 1 / prod(fft_lengths)
return scale * mask * x
out = scale * mask * x
assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
return out
def fft_transpose_rule(t, fft_type, fft_lengths):
if fft_type == xla_client.FftType.RFFT:
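For intuition, a small sketch (not from the patch) of the dtype contract the new assertions enforce: RFFT maps a real array to a complex one, so its transpose must map a complex cotangent back to the real input dtype.

import jax
import jax.numpy as jnp

x = jnp.zeros((8,), jnp.float32)
y, vjp_fn = jax.vjp(jnp.fft.rfft, x)   # y is complex64
ct, = vjp_fn(jnp.ones_like(y))         # cotangent of the real input
print(y.dtype, ct.dtype)               # complex64 float32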


@@ -266,7 +266,7 @@ sort = onp.sort
def sort_key_val(keys, values, dimension=-1):
idxs = list(onp.ix_(*[onp.arange(d) for d in keys.shape]))
idxs[dimension] = onp.argsort(keys, axis=dimension)
return keys[idxs], values[idxs]
return keys[tuple(idxs)], values[tuple(idxs)]
# TODO untake
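Background for the tuple() change, as a standalone sketch: indexing a NumPy array with a plain list of index arrays is deprecated (FutureWarning since NumPy 1.15), and the new pytest configuration turns that warning into an error.

import numpy as onp

keys = onp.array([[3, 1, 2], [6, 5, 4]])
idxs = list(onp.ix_(*[onp.arange(d) for d in keys.shape]))
idxs[-1] = onp.argsort(keys, axis=-1)
print(keys[tuple(idxs)])  # rows sorted; keys[idxs] would emit FutureWarning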


@@ -3435,7 +3435,7 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None,
@_wraps(onp.corrcoef)
def corrcoef(x, y=None, rowvar=True, bias=None, ddof=None):
def corrcoef(x, y=None, rowvar=True):
c = cov(x, y, rowvar)
if len(shape(c)) == 0:
# scalar - this should yield nan for values (nan/nan, inf/inf, 0/0), 1 otherwise
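The bias and ddof parameters are dropped because NumPy deprecated them for corrcoef in 1.10: they have no effect and merely emit a DeprecationWarning, which would now fail the warning-clean suite. Usage sketch:

import jax.numpy as jnp

x = jnp.array([[1., 2., 3., 4.],
               [10., 9., 2., 1.]])
print(jnp.corrcoef(x))  # 2x2 correlation matrix; no bias/ddof accepted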


@@ -20,7 +20,8 @@ import itertools as it
import os
from typing import Dict, Sequence, Union
import sys
from unittest import SkipTest
import unittest
import warnings
from absl.testing import absltest
from absl.testing import parameterized
@@ -336,7 +337,8 @@ def supported_dtypes():
def skip_if_unsupported_type(dtype):
if dtype not in supported_dtypes():
raise SkipTest(f"Type {dtype} not supported on {device_under_test()}")
raise unittest.SkipTest(
f"Type {dtype} not supported on {device_under_test()}")
def skip_on_devices(*disabled_devices):
"""A decorator for test methods to skip the test on certain devices."""
@@ -346,8 +348,8 @@ def skip_on_devices(*disabled_devices):
device = device_under_test()
if device in disabled_devices:
test_name = getattr(test_method, '__name__', '[unknown test]')
raise SkipTest('{} not supported on {}.'
.format(test_name, device.upper()))
raise unittest.SkipTest(
f"{test_name} not supported on {device.upper()}.")
return test_method(self, *args, **kwargs)
return test_method_wrapper
return skip
@@ -361,16 +363,17 @@ def skip_on_flag(flag_name, skip_value):
flag_value = getattr(FLAGS, flag_name)
if flag_value == skip_value:
test_name = getattr(test_method, '__name__', '[unknown test]')
raise SkipTest('{} not supported when FLAGS.{} is {}'
.format(test_name, flag_name, flag_value))
raise unittest.SkipTest(
f"{test_name} not supported when FLAGS.{flag_name} is {flag_value}")
return test_method(self, *args, **kwargs)
return test_method_wrapper
return skip
# TODO(phawkins): bug https://github.com/google/jax/issues/432
def skip_on_mac_xla_bug():
if sys.platform == "darwin" and scipy.version.version > "1.0.0":
raise absltest.SkipTest("Test fails on Mac with new scipy (issue #432)")
skip_on_mac_linalg_bug = partial(
unittest.skipIf,
sys.platform == "darwin" and scipy.version.version > "1.1.0",
"Test fails on Mac with new scipy (issue #432)")
def format_test_name_suffix(opname, shapes, dtypes):
@@ -793,3 +796,10 @@ class JaxTestCase(parameterized.TestCase):
numpy_ans = numpy_reference_op(*args)
self.assertAllClose(numpy_ans, lax_ans, check_dtypes=check_dtypes,
atol=tol, rtol=tol)
@contextmanager
def ignore_warning(**kw):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", **kw)
yield
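Because @contextmanager results double as decorators, the helper works both inline and on whole functions. A usage sketch (the jtu alias and function names are illustrative):

import numpy as onp
from jax import test_util as jtu

# As a context manager:
with jtu.ignore_warning(category=RuntimeWarning, message="overflow.*"):
  onp.exp(onp.float32(1e10))  # would otherwise warn about overflow

# As a decorator, the form most tests below use:
@jtu.ignore_warning(category=onp.ComplexWarning)
def as_real(x):
  return onp.asarray(x).astype(onp.float32)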

pytest.ini (new file, +6 lines)

@@ -0,0 +1,6 @@
[pytest]
filterwarnings =
error
ignore:No GPU/TPU found, falling back to CPU.:UserWarning
ignore:Explicitly requested dtype.*is not available.*:UserWarning
ignore:jax.experimental.vectorize is deprecated.*:FutureWarning
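For intuition, a rough Python-level equivalent of these filters (pytest installs them around each test):

import warnings

warnings.simplefilter("error")  # escalate every warning to an exception
warnings.filterwarnings(        # ...except the whitelisted ones
    "ignore",
    message="No GPU/TPU found, falling back to CPU.",
    category=UserWarning)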


@@ -212,7 +212,7 @@ class APITest(jtu.JaxTestCase):
self.assertRaisesRegex(
TypeError,
"Try using `x.astype\({}\)` instead.".format(castfun.__name__),
f"Try using `x.astype\\({castfun.__name__}\\)` instead.",
lambda: jit(f)(1.0))
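Background for the pattern change above: "\(" inside a normal (non-raw) string is an invalid escape sequence, a DeprecationWarning that this commit now promotes to an error. Two safe spellings, sketched with a hypothetical pattern:

pat = "Try using `x.astype\\(int\\)` instead."  # doubled backslash
pat = r"Try using `x.astype\(int\)` instead."   # raw string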
def test_switch_value_jit(self):
@@ -1018,7 +1018,7 @@ class APITest(jtu.JaxTestCase):
X = onp.random.randn(10, 4)
U = onp.random.randn(10, 2)
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
ValueError,
"vmap got inconsistent sizes for array axes to be mapped:\n"
r"arg 0 has shape \(10, 4\) and axis 0 is to be mapped" "\n"
@@ -1028,7 +1028,7 @@
"arg 1 has an axis to be mapped of size 2"):
api.vmap(h, in_axes=(0, 1))(X, U)
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
ValueError,
"vmap got inconsistent sizes for array axes to be mapped:\n"
r"arg 0 has shape \(10, 4\) and axis 0 is to be mapped" "\n"
@@ -1039,32 +1039,38 @@
"arg 1 has an axis to be mapped of size 2"):
api.vmap(lambda x, y, z: None, in_axes=(0, 1, 0))(X, U, X)
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
ValueError,
"vmap got inconsistent sizes for array axes to be mapped:\n"
"the tree of axis sizes is:\n"
r"\(10, \[2, 2\]\)"):
api.vmap(h, in_axes=(0, 1))(X, [U, U])
with self.assertRaisesRegex(ValueError, "vmap got arg 0 of rank 0 but axis to be mapped 0"):
with self.assertRaisesRegex(
ValueError, "vmap got arg 0 of rank 0 but axis to be mapped 0"):
# The mapped inputs cannot be scalars
api.vmap(lambda x: x)(1.)
with self.assertRaisesRegexp(ValueError, re.escape("vmap got arg 0 of rank 1 but axis to be mapped [1. 2.]")):
with self.assertRaisesRegex(
ValueError, re.escape("vmap got arg 0 of rank 1 but axis to be mapped [1. 2.]")):
api.vmap(lambda x: x, in_axes=(np.array([1., 2.]),))(np.array([1., 2.]))
with self.assertRaisesRegex(ValueError, "vmap must have at least one non-None in_axes"):
with self.assertRaisesRegex(
ValueError, "vmap must have at least one non-None in_axes"):
# If the output is mapped, there must be a non-None in_axes
api.vmap(lambda x: x, in_axes=None)(np.array([1., 2.]))
with self.assertRaisesRegexp(ValueError, "vmap got arg 0 of rank 1 but axis to be mapped 1"):
with self.assertRaisesRegex(
ValueError, "vmap got arg 0 of rank 1 but axis to be mapped 1"):
api.vmap(lambda x: x, in_axes=1)(np.array([1., 2.]))
# Error is: TypeError: only integer scalar arrays can be converted to a scalar index
with self.assertRaisesRegexp(ValueError, "axes specification must be a tree prefix of the corresponding value"):
with self.assertRaisesRegex(
ValueError, "axes specification must be a tree prefix of the corresponding value"):
api.vmap(lambda x: x, in_axes=0, out_axes=(2, 3))(np.array([1., 2.]))
with self.assertRaisesRegexp(ValueError, "vmap has mapped output but out_axes is None"):
with self.assertRaisesRegex(
ValueError, "vmap has mapped output but out_axes is None"):
# If the output is mapped, then there must be some out_axes specified
api.vmap(lambda x: x, out_axes=None)(np.array([1., 2.]))


@@ -49,6 +49,11 @@ all_dtypes = float_dtypes + int_dtypes + bool_types
IndexSpec = collections.namedtuple("IndexTest", ["shape", "indexer"])
suppress_deprecated_indexing_warnings = partial(
jtu.ignore_warning, category=FutureWarning,
message='Using a non-tuple sequence.*')
def check_grads(f, args, order, atol=None, rtol=None, eps=None):
# TODO(mattjj,dougalm): add higher-order check
default_tol = 1e-6 if FLAGS.jax_enable_x64 else 1e-2
@@ -548,7 +553,7 @@ class IndexingTest(jtu.JaxTestCase):
def testAdvancedIntegerIndexing(self, shape, dtype, rng_factory, indexer):
rng = rng_factory()
args_maker = lambda: [rng(shape, dtype), indexer]
fun = lambda x, idx: x[idx]
fun = lambda x, idx: jnp.asarray(x)[idx]
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(
@@ -608,7 +613,7 @@
rng = rng_factory()
tol = 1e-2 if jnp.finfo(dtype).bits == 32 else None
arg = rng(shape, dtype)
fun = lambda x: x[indexer]**2
fun = lambda x: jnp.asarray(x)[indexer]**2
check_grads(fun, (arg,), 2, tol, tol, tol)
@parameterized.named_parameters(
@@ -629,7 +634,7 @@
def fun(x, indexer_with_dummies):
idx = type(indexer)(util.subvals(indexer_with_dummies, substitutes))
return x[idx]
return jnp.asarray(x)[idx]
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
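Why jnp.asarray(x) is threaded through these tests, as a sketch: applying the parameterized indexers to a raw NumPy array can take NumPy's deprecated non-tuple-sequence indexing path, while a JAX array routes through jax.numpy's own indexing.

import numpy as onp
import jax.numpy as jnp

x = onp.arange(12).reshape(3, 4)
idx = (onp.array([0, 2]), onp.array([1, 3]))
print(jnp.asarray(x)[idx])  # [ 1 11], with no NumPy indexing warnings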
@@ -803,6 +808,7 @@ def _broadcastable_shapes(shape):
for x in f(list(reversed(shape))):
yield list(reversed(x))
@suppress_deprecated_indexing_warnings()
def _update_shape(shape, indexer):
return onp.zeros(shape)[indexer].shape
@@ -813,6 +819,7 @@ class UpdateOps(enum.Enum):
MIN = 2
MAX = 3
@suppress_deprecated_indexing_warnings()
def onp_fn(op, indexer, x, y):
x = x.copy()
x[indexer] = {


@@ -448,6 +448,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
JAX_COMPOUND_OP_RECORDS)))
def testOp(self, onp_op, jnp_op, rng_factory, shapes, dtypes, check_dtypes,
tolerance, inexact):
if onp_op is onp.float_power:
onp_op = jtu.ignore_warning(category=RuntimeWarning,
message="invalid value.*")(onp_op)
rng = rng_factory()
args_maker = self._GetArgsMaker(rng, shapes, dtypes, onp_arrays=False)
tol = max(jtu.tolerance(dtype, tolerance) for dtype in dtypes)
@@ -580,12 +584,16 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def testReducer(self, onp_op, jnp_op, rng_factory, shape, dtype, out_dtype,
axis, keepdims, inexact):
rng = rng_factory()
@jtu.ignore_warning(category=onp.ComplexWarning)
@jtu.ignore_warning(category=RuntimeWarning,
message="mean of empty slice.*")
def onp_fun(x):
x_cast = x if dtype != jnp.bfloat16 else x.astype(onp.float32)
t = out_dtype if out_dtype != jnp.bfloat16 else onp.float32
return onp_op(x_cast, axis, dtype=t, keepdims=keepdims)
onp_fun = _promote_like_jnp(onp_fun, inexact)
jnp_fun = lambda x: jnp_op(x, axis, dtype=out_dtype, keepdims=keepdims)
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
args_maker = lambda: [rng(shape, dtype)]
tol_spec = {onp.float16: 1e-2, onp.float32: 1e-3, onp.complex64: 1e-3,
onp.float64: 1e-5, onp.complex128: 1e-5}
@@ -613,7 +621,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
rng = rng_factory()
onp_fun = lambda x: onp_op(x, axis, keepdims=keepdims)
onp_fun = _promote_like_jnp(onp_fun, inexact)
onp_fun = jtu.ignore_warning(category=onp.ComplexWarning)(onp_fun)
jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims)
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
@@ -640,6 +650,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def testNonzero(self, shape, dtype):
rng = jtu.rand_some_zero()
onp_fun = lambda x: onp.nonzero(x)
onp_fun = jtu.ignore_warning(
category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*")(onp_fun)
jnp_fun = lambda x: jnp.nonzero(x)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=False)
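Context for the filter above, sketched: NumPy deprecated calling nonzero on 0-d arrays, and the resulting DeprecationWarning would now fail the suite for scalar-shaped test cases.

import numpy as onp

print(onp.nonzero(onp.array([0, 3, 0])))  # fine: (array([1]),)
# onp.nonzero(onp.array(3))  # DeprecationWarning on a 0-d array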
@@ -1158,7 +1171,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def testCumSumProd(self, axis, shape, dtype, out_dtype, onp_op, jnp_op, rng_factory):
rng = rng_factory()
onp_fun = lambda arg: onp_op(arg, axis=axis, dtype=out_dtype)
onp_fun = jtu.ignore_warning(category=onp.ComplexWarning)(onp_fun)
jnp_fun = lambda arg: jnp_op(arg, axis=axis, dtype=out_dtype)
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
args_maker = lambda: [rng(shape, dtype)]
@@ -1278,10 +1293,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self.skipTest("Only run float64 testcase when float64 is enabled.")
x1_rng = x1_rng_factory()
x2_rng = x2_rng_factory()
def onp_fun(x1, x2):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=RuntimeWarning)
return onp.ldexp(x1, x2)
onp_fun = lambda x1, x2: onp.ldexp(x1, x2)
onp_fun = jtu.ignore_warning(category=RuntimeWarning,
message="overflow.*")(onp_fun)
jnp_fun = lambda x1, x2: jnp.ldexp(x1, x2)
args_maker = lambda: [x1_rng(x1_shape, x1_dtype),
x2_rng(x2_shape, onp.int32)]
@@ -2237,6 +2251,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def testWhereOneArgument(self, shape, dtype):
rng = jtu.rand_some_zero()
onp_fun = lambda x: onp.where(x)
onp_fun = jtu.ignore_warning(
category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*")(onp_fun)
jnp_fun = lambda x: jnp.where(x)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=False)
@@ -2461,6 +2478,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
"log", "expm1", "log1p")))
def testMathSpecialFloatValues(self, op, dtype):
onp_op = getattr(onp, op)
onp_op = jtu.ignore_warning(category=RuntimeWarning,
message="invalid value.*")(onp_op)
onp_op = jtu.ignore_warning(category=RuntimeWarning,
message="divide by zero.*")(onp_op)
onp_op = jtu.ignore_warning(category=RuntimeWarning,
message="overflow.*")(onp_op)
jnp_op = getattr(jnp, op)
dtype = onp.dtype(dtypes.canonicalize_dtype(dtype)).type
for x in (onp.nan, -onp.inf, -100., -2., -1., 0., 1., 2., 100., onp.inf,
@@ -2559,26 +2583,24 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
@parameterized.named_parameters(
jtu.cases_from_list(
{"testcase_name": "_shape={}_dtype={}_rowvar={}_ddof={}_bias={}".format(
shape, dtype, rowvar, ddof, bias),
"shape": shape, "dtype": dtype, "rowvar": rowvar, "ddof": ddof,
"bias": bias, "rng_factory": rng_factory}
{"testcase_name": "_shape={}_dtype={}_rowvar={}".format(
shape, dtype, rowvar),
"shape": shape, "dtype": dtype, "rowvar": rowvar,
"rng_factory": rng_factory}
for shape in [(5,), (10, 5), (3, 10)]
for dtype in number_dtypes
for rowvar in [True, False]
for bias in [True, False]
for ddof in [None, 2, 3]
for rng_factory in [jtu.rand_default]))
def testCorrCoef(self, shape, dtype, rowvar, ddof, bias, rng_factory):
def testCorrCoef(self, shape, dtype, rowvar, rng_factory):
rng = rng_factory()
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
mat = onp.asarray([rng(shape, dtype)])
onp_fun = partial(onp.corrcoef, rowvar=rowvar, ddof=ddof, bias=bias)
jnp_fun = partial(jnp.corrcoef, rowvar=rowvar, ddof=ddof, bias=bias)
onp_fun = partial(onp.corrcoef, rowvar=rowvar)
jnp_fun = partial(jnp.corrcoef, rowvar=rowvar)
if not onp.any(onp.isclose(onp.std(mat), 0.0)):
self._CheckAgainstNumpy(
onp_fun, jnp_fun, args_maker, check_dtypes=False,
tol=1e-2 if jtu.device_under_test() == "tpu" else None)
self._CheckAgainstNumpy(
onp_fun, jnp_fun, args_maker, check_dtypes=False,
tol=1e-2 if jtu.device_under_test() == "tpu" else None)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(


@@ -154,7 +154,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
A = lambda x: x
b = jnp.zeros((2, 1))
x0 = jnp.zeros((2,))
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
ValueError, "x0 and b must have matching shape"):
jax.scipy.sparse.linalg.cg(A, b, x0)


@@ -1901,6 +1901,8 @@ class LaxAutodiffTest(jtu.JaxTestCase):
jtu.tolerance(from_dtype, jtu.default_gradient_tolerance))
args = (rng((2, 3), from_dtype),)
convert_element_type = lambda x: lax.convert_element_type(x, to_dtype)
convert_element_type = jtu.ignore_warning(category=onp.ComplexWarning)(
convert_element_type)
check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.)
@parameterized.named_parameters(jtu.cases_from_list(
@@ -3111,6 +3113,7 @@ class LaxVmapTest(jtu.JaxTestCase):
for padding in ["VALID", "SAME"]
for rng_factory in [jtu.rand_small]))
@jtu.skip_on_flag("jax_skip_slow_tests", True)
@jtu.ignore_warning(message="Using reduced precision for gradient.*")
def testSelectAndGatherAdd(self, dtype, padding, rng_factory):
if jtu.device_under_test() == "tpu" and dtype == dtypes.bfloat16:
raise SkipTest("bfloat16 _select_and_gather_add doesn't work on tpu")


@@ -62,6 +62,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (1000, 0, 0)]
for dtype in float_types + complex_types
for rng_factory in [jtu.rand_default]))
@jtu.skip_on_mac_linalg_bug()
def testCholesky(self, shape, dtype, rng_factory):
rng = rng_factory()
_skip_if_unsupported_type(dtype)
@@ -121,11 +122,10 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for nq in zip([2, 4, 6, 36], [(1, 2), (2, 2), (1, 2, 3), (3, 3, 1, 4)])
for dtype in float_types
for rng_factory in [jtu.rand_default]))
@jtu.skip_on_mac_linalg_bug()
def testTensorsolve(self, m, nq, dtype, rng_factory):
rng = rng_factory()
_skip_if_unsupported_type(dtype)
if m == 23:
jtu.skip_on_mac_xla_bug()
# According to numpy docs the shapes are as follows:
# Coefficient tensor (a), of shape b.shape + Q.
@@ -159,6 +159,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for dtype in float_types + complex_types
for rng_factory in [jtu.rand_default]))
@jtu.skip_on_devices("tpu")
@jtu.skip_on_mac_linalg_bug()
def testSlogdet(self, shape, dtype, rng_factory):
rng = rng_factory()
_skip_if_unsupported_type(dtype)
@@ -200,6 +201,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
# TODO(phawkins): enable when there is an eigendecomposition implementation
# for GPU/TPU.
@jtu.skip_on_devices("gpu", "tpu")
@jtu.skip_on_mac_linalg_bug()
def testEig(self, shape, dtype, rng_factory):
rng = rng_factory()
_skip_if_unsupported_type(dtype)
@@ -228,11 +230,10 @@ class NumpyLinalgTest(jtu.JaxTestCase):
# TODO: enable when there is an eigendecomposition implementation
# for GPU/TPU.
@jtu.skip_on_devices("gpu", "tpu")
@jtu.skip_on_mac_linalg_bug()
def testEigvals(self, shape, dtype, rng_factory):
rng = rng_factory()
_skip_if_unsupported_type(dtype)
if shape == (50, 50) and dtype == onp.complex64:
jtu.skip_on_mac_xla_bug()
n = shape[-1]
args_maker = lambda: [rng(shape, dtype)]
a, = args_maker()
@@ -622,10 +623,9 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for shape in [(1, 1), (4, 4), (200, 200), (7, 7, 7, 7)]
for dtype in float_types
for rng_factory in [jtu.rand_default]))
@jtu.skip_on_mac_linalg_bug()
def testTensorinv(self, shape, dtype, rng_factory):
_skip_if_unsupported_type(dtype)
if shape[0] > 100:
jtu.skip_on_mac_xla_bug()
rng = rng_factory()
def tensor_maker():
@@ -677,11 +677,10 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (5, 5, 5)]
for dtype in float_types
for rng_factory in [jtu.rand_default]))
@jtu.skip_on_mac_linalg_bug()
def testInv(self, shape, dtype, rng_factory):
rng = rng_factory()
_skip_if_unsupported_type(dtype)
if shape == (200, 200) and dtype == onp.float32:
jtu.skip_on_mac_xla_bug()
if jtu.device_under_test() == "gpu" and shape == (200, 200):
raise unittest.SkipTest("Test is flaky on GPU")
@@ -708,11 +707,10 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for dtype in float_types + complex_types
for rng_factory in [jtu.rand_default]))
@jtu.skip_on_devices("tpu") # SVD is not implemented on the TPU backend
@jtu.skip_on_mac_linalg_bug()
def testPinv(self, shape, dtype, rng_factory):
rng = rng_factory()
_skip_if_unsupported_type(dtype)
if shape == (7, 10000) and dtype in [onp.complex64, onp.float32]:
jtu.skip_on_mac_xla_bug()
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(onp.linalg.pinv, np.linalg.pinv, args_maker,
@@ -824,6 +822,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
for shape in [(1, 1), (4, 5), (10, 5), (50, 50)]
for dtype in float_types + complex_types
for rng_factory in [jtu.rand_default]))
@jtu.skip_on_mac_linalg_bug()
def testLu(self, shape, dtype, rng_factory):
rng = rng_factory()
_skip_if_unsupported_type(dtype)
@@ -884,11 +883,10 @@ class ScipyLinalgTest(jtu.JaxTestCase):
for n in [1, 4, 5, 200]
for dtype in float_types + complex_types
for rng_factory in [jtu.rand_default]))
@jtu.skip_on_mac_linalg_bug()
def testLuFactor(self, n, dtype, rng_factory):
rng = rng_factory()
_skip_if_unsupported_type(dtype)
if n == 200 and dtype == onp.complex64:
jtu.skip_on_mac_xla_bug()
args_maker = lambda: [rng((n, n), dtype)]
x, = args_maker()
@@ -1109,11 +1107,10 @@ class ScipyLinalgTest(jtu.JaxTestCase):
for n in [1, 4, 5, 20, 50, 100]
for dtype in float_types + complex_types
for rng_factory in [jtu.rand_small]))
@jtu.skip_on_mac_linalg_bug()
def testExpm(self, n, dtype, rng_factory):
rng = rng_factory()
_skip_if_unsupported_type(dtype)
if n == 50 and dtype in [onp.complex64, onp.float32]:
jtu.skip_on_mac_xla_bug()
args_maker = lambda: [rng((n, n), dtype)]
osp_fun = lambda a: osp.linalg.expm(a)
@@ -1135,8 +1132,8 @@ class ScipyLinalgTest(jtu.JaxTestCase):
for n in [1, 4, 5, 20, 50, 100]
for dtype in float_types + complex_types
))
@jtu.skip_on_mac_linalg_bug()
def testIssue2131(self, n, dtype):
jtu.skip_on_mac_xla_bug()
args_maker_zeros = lambda: [onp.zeros((n, n), dtype)]
osp_fun = lambda a: osp.linalg.expm(a)
jsp_fun = lambda a: jsp.linalg.expm(a)


@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import itertools
import unittest
from unittest import SkipTest, skip
@@ -32,6 +32,9 @@ from jax.config import config
config.parse_flags_with_absl()
ignore_soft_pmap_warning = functools.partial(
jtu.ignore_warning, message="soft_pmap is an experimental.*")
class PapplyTest(jtu.JaxTestCase):
def testIdentity(self):
@@ -46,6 +49,7 @@ class PapplyTest(jtu.JaxTestCase):
expected = onp.sin(onp.arange(3.))
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_soft_pmap_warning()
def testSum(self):
pfun, axis_name = _papply(lambda x: np.sum(x, axis=0))
@@ -59,6 +63,7 @@ class PapplyTest(jtu.JaxTestCase):
expected = onp.sum(arg, axis=0)
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_soft_pmap_warning()
def testMax(self):
pfun, axis_name = _papply(lambda x: np.max(x, axis=0))
@@ -72,6 +77,7 @@ class PapplyTest(jtu.JaxTestCase):
expected = onp.max(arg, axis=0)
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_soft_pmap_warning()
def testSelect(self):
p = onp.arange(15).reshape((5, 3)) % 4 == 1
f = onp.zeros((5, 3))
@@ -101,6 +107,7 @@ class PapplyTest(jtu.JaxTestCase):
expected = fun(onp.arange(1., 5.))
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_soft_pmap_warning()
def testAdd(self):
x = onp.array([[1, 2, 3], [4, 5, 6]])
expected = x + x


@@ -61,6 +61,9 @@ def tearDownModule():
os.environ["XLA_FLAGS"] = prev_xla_flags
xla_bridge.get_backend.cache_clear()
ignore_soft_pmap_warning = partial(
jtu.ignore_warning, message="soft_pmap is an experimental.*")
class PmapTest(jtu.JaxTestCase):
def _getMeshShape(self, device_mesh_shape):
@@ -434,7 +437,7 @@ class PmapTest(jtu.JaxTestCase):
g = lambda: pmap(f, "i")(onp.arange(device_count))
self.assertRaisesRegex(
AssertionError,
"Given `perm` does not represent a real permutation: \[1.*\]", g)
"Given `perm` does not represent a real permutation: \\[1.*\\]", g)
@jtu.skip_on_devices("cpu", "gpu")
def testPpermuteWithZipObject(self):
@@ -772,6 +775,7 @@ class PmapTest(jtu.JaxTestCase):
expected = onp.swapaxes(x, 0, 2)
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_soft_pmap_warning()
def testSoftPmapPsum(self):
n = 4 * xla_bridge.device_count()
def f(x):
@@ -780,6 +784,7 @@ class PmapTest(jtu.JaxTestCase):
expected = onp.ones(n) / n
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_soft_pmap_warning()
def testSoftPmapAxisIndex(self):
n = 4 * xla_bridge.device_count()
def f(x):
@@ -788,6 +793,7 @@ class PmapTest(jtu.JaxTestCase):
expected = 2 * onp.arange(n)
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_soft_pmap_warning()
def testSoftPmapOfJit(self):
n = 4 * xla_bridge.device_count()
def f(x):
@@ -796,6 +802,7 @@ class PmapTest(jtu.JaxTestCase):
expected = 3 * onp.arange(n)
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_soft_pmap_warning()
def testSoftPmapNested(self):
n = 4 * xla_bridge.device_count()
@@ -809,6 +816,7 @@ class PmapTest(jtu.JaxTestCase):
expected = onp.arange(n ** 2).reshape(n, n).T
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_soft_pmap_warning()
def testGradOfSoftPmap(self):
n = 4 * xla_bridge.device_count()
@@ -820,6 +828,7 @@ class PmapTest(jtu.JaxTestCase):
expected = onp.repeat(onp.arange(n)[:, None], n, axis=1)
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_soft_pmap_warning()
def testSoftPmapDevicePersistence(self):
device_count = xla_bridge.device_count()
shape = (2 * 2 * device_count, 2, 3)


@@ -175,7 +175,6 @@ class LaxRandomTest(jtu.JaxTestCase):
for p in [0.1, 0.5, 0.9]
for dtype in [onp.float32, onp.float64]))
def testBernoulli(self, p, dtype):
jtu.skip_on_mac_xla_bug()
key = random.PRNGKey(0)
p = onp.array(p, dtype=dtype)
rand = lambda key, p: random.bernoulli(key, p, (10000,))
@@ -194,7 +193,6 @@ class LaxRandomTest(jtu.JaxTestCase):
for sample_shape in [(10000,), (5000, 2)]
for dtype in [onp.float32, onp.float64]))
def testCategorical(self, p, axis, dtype, sample_shape):
jtu.skip_on_mac_xla_bug()
key = random.PRNGKey(0)
p = onp.array(p, dtype=dtype)
logits = onp.log(p) - 42 # test unnormalized
@@ -245,7 +243,6 @@ class LaxRandomTest(jtu.JaxTestCase):
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
for dtype in [onp.float32, onp.float64]))
def testCauchy(self, dtype):
jtu.skip_on_mac_xla_bug()
key = random.PRNGKey(0)
rand = lambda key: random.cauchy(key, (10000,), dtype)
crand = api.jit(rand)
@@ -264,7 +261,6 @@ class LaxRandomTest(jtu.JaxTestCase):
]
for dtype in [onp.float32, onp.float64]))
def testDirichlet(self, alpha, dtype):
jtu.skip_on_mac_xla_bug()
key = random.PRNGKey(0)
rand = lambda key, alpha: random.dirichlet(key, alpha, (10000,), dtype)
crand = api.jit(rand)
@@ -282,7 +278,6 @@ class LaxRandomTest(jtu.JaxTestCase):
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
for dtype in [onp.float32, onp.float64]))
def testExponential(self, dtype):
jtu.skip_on_mac_xla_bug()
key = random.PRNGKey(0)
rand = lambda key: random.exponential(key, (10000,), dtype)
crand = api.jit(rand)
@@ -299,7 +294,6 @@ class LaxRandomTest(jtu.JaxTestCase):
for a in [0.1, 1., 10.]
for dtype in [onp.float32, onp.float64]))
def testGamma(self, a, dtype):
jtu.skip_on_mac_xla_bug()
key = random.PRNGKey(0)
rand = lambda key, a: random.gamma(key, a, (10000,), dtype)
crand = api.jit(rand)
@@ -346,7 +340,6 @@ class LaxRandomTest(jtu.JaxTestCase):
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
for dtype in [onp.float32, onp.float64]))
def testGumbel(self, dtype):
jtu.skip_on_mac_xla_bug()
key = random.PRNGKey(0)
rand = lambda key: random.gumbel(key, (10000,), dtype)
crand = api.jit(rand)
@@ -428,8 +421,8 @@ class LaxRandomTest(jtu.JaxTestCase):
"dim": dim, "dtype": dtype}
for dim in [1, 3, 5]
for dtype in [onp.float32, onp.float64]))
@jtu.skip_on_mac_linalg_bug()
def testMultivariateNormal(self, dim, dtype):
jtu.skip_on_mac_xla_bug()
r = onp.random.RandomState(dim)
mean = r.randn(dim)
cov_factor = r.randn(dim, dim)
@@ -453,9 +446,9 @@ class LaxRandomTest(jtu.JaxTestCase):
# eigenvectors follow a standard normal distribution.
self._CheckKolmogorovSmirnovCDF(whitened.ravel(), scipy.stats.norm().cdf)
@jtu.skip_on_mac_linalg_bug()
def testMultivariateNormalCovariance(self):
# test code based on https://github.com/google/jax/issues/1869
jtu.skip_on_mac_xla_bug()
N = 100000
cov = np.array([[ 0.19, 0.00, -0.13, 0.00],
[ 0.00, 0.29, 0.00, -0.23],


@@ -14,10 +14,12 @@
import collections
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import jax.lib
from jax import test_util as jtu
from jax import tree_util
@@ -157,11 +159,13 @@ class TreeTest(jtu.JaxTestCase):
((3, {"foo": "bar"}), (4, 7), (5, [5, 6]))))
@parameterized.parameters(*TREES)
@unittest.skipIf(jax.lib.version < (0, 1, 44), "Jaxlib too old")
def testAllLeavesWithTrees(self, tree):
leaves = tree_util.tree_leaves(tree)
self.assertTrue(tree_util.all_leaves(leaves))
self.assertFalse(tree_util.all_leaves([tree]))
@unittest.skipIf(jax.lib.version < (0, 1, 44), "Jaxlib too old")
@parameterized.parameters(*LEAVES)
def testAllLeavesWithLeaves(self, leaf):
self.assertTrue(tree_util.all_leaves([leaf]))