mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Make pytest run over JAX tests warning clean, and error on warnings. (#2674)
* Make pytest run over JAX tests warning clean, and error on warnings. Remove global warning suppression in travis.yml. Instead add a pytest.ini that converts warnings to errors, with the exception of a whitelist. Either fix or locally suppress warnings in tests. Also fix crashes on Mac related to a preexisting linear algebra bug. * Fix some type errors in the FFT transpose rules revealed by the convert_element_type transpose rule change.
This commit is contained in:
parent
453dc5f085
commit
2dc81fb40c
@ -47,5 +47,5 @@ script:
|
||||
echo "===== Checking with mypy ====" &&
|
||||
time mypy --config-file=mypy.ini jax ;
|
||||
else
|
||||
pytest tests examples -W ignore ;
|
||||
pytest tests examples ;
|
||||
fi
|
||||
|
@ -372,8 +372,6 @@ def convert_element_type(operand: Array, new_dtype: DType) -> Array:
|
||||
not dtypes.issubdtype(new_dtype, onp.complexfloating)):
|
||||
msg = "Casting complex values to real discards the imaginary part"
|
||||
warnings.warn(msg, onp.ComplexWarning, stacklevel=2)
|
||||
operand = real(operand)
|
||||
old_dtype = _dtype(operand)
|
||||
return convert_element_type_p.bind(
|
||||
operand, new_dtype=new_dtype, old_dtype=old_dtype)
|
||||
|
||||
@ -2109,15 +2107,21 @@ def _convert_element_type_dtype_rule(operand, *, new_dtype, old_dtype):
|
||||
return new_dtype
|
||||
|
||||
def _convert_element_type_translation_rule(c, operand, *, new_dtype, old_dtype):
|
||||
if (dtypes.issubdtype(old_dtype, onp.complexfloating) and
|
||||
not dtypes.issubdtype(new_dtype, onp.complexfloating)):
|
||||
operand = c.Real(operand)
|
||||
new_etype = xla_client.dtype_to_etype(new_dtype)
|
||||
return c.ConvertElementType(operand, new_element_type=new_etype)
|
||||
|
||||
def _convert_element_type_transpose_rule(t, *, new_dtype, old_dtype):
|
||||
assert t.dtype == new_dtype, (t.dtype, new_dtype)
|
||||
return [convert_element_type_p.bind(t, new_dtype=old_dtype,
|
||||
old_dtype=new_dtype)]
|
||||
|
||||
convert_element_type_p = standard_primitive(
|
||||
_convert_element_type_shape_rule, _convert_element_type_dtype_rule,
|
||||
'convert_element_type', _convert_element_type_translation_rule)
|
||||
ad.deflinear(
|
||||
convert_element_type_p,
|
||||
lambda t, new_dtype, old_dtype: [convert_element_type(t, old_dtype)])
|
||||
ad.deflinear(convert_element_type_p, _convert_element_type_transpose_rule)
|
||||
batching.defvectorized(convert_element_type_p)
|
||||
masking.defvectorized(convert_element_type_p)
|
||||
|
||||
|
@ -94,9 +94,10 @@ def _rfft_transpose(t, fft_lengths):
|
||||
# asymptotic complexity and is also rather complicated), we rely JAX to
|
||||
# transpose a naive RFFT implementation.
|
||||
dummy_shape = t.shape[:-len(fft_lengths)] + fft_lengths
|
||||
dummy_primals = lax.full_like(t, 0.0, onp.float64, dummy_shape)
|
||||
dummy_primals = lax.full_like(t, 0.0, _real_dtype(t.dtype), dummy_shape)
|
||||
_, jvpfun = vjp(partial(_naive_rfft, fft_lengths=fft_lengths), dummy_primals)
|
||||
result, = jvpfun(t)
|
||||
assert result.dtype == _real_dtype(t.dtype), (result.dtype, t.dtype)
|
||||
return result
|
||||
|
||||
def _irfft_transpose(t, fft_lengths):
|
||||
@ -107,14 +108,16 @@ def _irfft_transpose(t, fft_lengths):
|
||||
x = fft(t, xla_client.FftType.RFFT, fft_lengths)
|
||||
n = x.shape[-1]
|
||||
is_odd = fft_lengths[-1] % 2
|
||||
full = partial(lax.full_like, t, dtype=onp.float64)
|
||||
full = partial(lax.full_like, t, dtype=t.dtype)
|
||||
mask = lax.concatenate(
|
||||
[full(1.0, shape=(1,)),
|
||||
full(2.0, shape=(n - 2 + is_odd,)),
|
||||
full(1.0, shape=(1 - is_odd,))],
|
||||
dimension=0)
|
||||
scale = 1 / prod(fft_lengths)
|
||||
return scale * mask * x
|
||||
out = scale * mask * x
|
||||
assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
|
||||
return out
|
||||
|
||||
def fft_transpose_rule(t, fft_type, fft_lengths):
|
||||
if fft_type == xla_client.FftType.RFFT:
|
||||
|
@ -266,7 +266,7 @@ sort = onp.sort
|
||||
def sort_key_val(keys, values, dimension=-1):
|
||||
idxs = list(onp.ix_(*[onp.arange(d) for d in keys.shape]))
|
||||
idxs[dimension] = onp.argsort(keys, axis=dimension)
|
||||
return keys[idxs], values[idxs]
|
||||
return keys[tuple(idxs)], values[tuple(idxs)]
|
||||
|
||||
# TODO untake
|
||||
|
||||
|
@ -3435,7 +3435,7 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None,
|
||||
|
||||
|
||||
@_wraps(onp.corrcoef)
|
||||
def corrcoef(x, y=None, rowvar=True, bias=None, ddof=None):
|
||||
def corrcoef(x, y=None, rowvar=True):
|
||||
c = cov(x, y, rowvar)
|
||||
if len(shape(c)) == 0:
|
||||
# scalar - this should yield nan for values (nan/nan, inf/inf, 0/0), 1 otherwise
|
||||
|
@ -20,7 +20,8 @@ import itertools as it
|
||||
import os
|
||||
from typing import Dict, Sequence, Union
|
||||
import sys
|
||||
from unittest import SkipTest
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
@ -336,7 +337,8 @@ def supported_dtypes():
|
||||
|
||||
def skip_if_unsupported_type(dtype):
|
||||
if dtype not in supported_dtypes():
|
||||
raise SkipTest(f"Type {dtype} not supported on {device_under_test()}")
|
||||
raise unittest.SkipTest(
|
||||
f"Type {dtype} not supported on {device_under_test()}")
|
||||
|
||||
def skip_on_devices(*disabled_devices):
|
||||
"""A decorator for test methods to skip the test on certain devices."""
|
||||
@ -346,8 +348,8 @@ def skip_on_devices(*disabled_devices):
|
||||
device = device_under_test()
|
||||
if device in disabled_devices:
|
||||
test_name = getattr(test_method, '__name__', '[unknown test]')
|
||||
raise SkipTest('{} not supported on {}.'
|
||||
.format(test_name, device.upper()))
|
||||
raise unittest.SkipTest(
|
||||
f"{test_name} not supported on {device.upper()}.")
|
||||
return test_method(self, *args, **kwargs)
|
||||
return test_method_wrapper
|
||||
return skip
|
||||
@ -361,16 +363,17 @@ def skip_on_flag(flag_name, skip_value):
|
||||
flag_value = getattr(FLAGS, flag_name)
|
||||
if flag_value == skip_value:
|
||||
test_name = getattr(test_method, '__name__', '[unknown test]')
|
||||
raise SkipTest('{} not supported when FLAGS.{} is {}'
|
||||
.format(test_name, flag_name, flag_value))
|
||||
raise unittest.SkipTest(
|
||||
f"{test_name} not supported when FLAGS.{flag_name} is {flag_value}")
|
||||
return test_method(self, *args, **kwargs)
|
||||
return test_method_wrapper
|
||||
return skip
|
||||
|
||||
# TODO(phawkins): bug https://github.com/google/jax/issues/432
|
||||
def skip_on_mac_xla_bug():
|
||||
if sys.platform == "darwin" and scipy.version.version > "1.0.0":
|
||||
raise absltest.SkipTest("Test fails on Mac with new scipy (issue #432)")
|
||||
skip_on_mac_linalg_bug = partial(
|
||||
unittest.skipIf,
|
||||
sys.platform == "darwin" and scipy.version.version > "1.1.0",
|
||||
"Test fails on Mac with new scipy (issue #432)")
|
||||
|
||||
|
||||
def format_test_name_suffix(opname, shapes, dtypes):
|
||||
@ -793,3 +796,10 @@ class JaxTestCase(parameterized.TestCase):
|
||||
numpy_ans = numpy_reference_op(*args)
|
||||
self.assertAllClose(numpy_ans, lax_ans, check_dtypes=check_dtypes,
|
||||
atol=tol, rtol=tol)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def ignore_warning(**kw):
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", **kw)
|
||||
yield
|
||||
|
6
pytest.ini
Normal file
6
pytest.ini
Normal file
@ -0,0 +1,6 @@
|
||||
[pytest]
|
||||
filterwarnings =
|
||||
error
|
||||
ignore:No GPU/TPU found, falling back to CPU.:UserWarning
|
||||
ignore:Explicitly requested dtype.*is not available.*:UserWarning
|
||||
ignore:jax.experimental.vectorize is deprecated.*:FutureWarning
|
@ -212,7 +212,7 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"Try using `x.astype\({}\)` instead.".format(castfun.__name__),
|
||||
f"Try using `x.astype\\({castfun.__name__}\\)` instead.",
|
||||
lambda: jit(f)(1.0))
|
||||
|
||||
def test_switch_value_jit(self):
|
||||
@ -1018,7 +1018,7 @@ class APITest(jtu.JaxTestCase):
|
||||
X = onp.random.randn(10, 4)
|
||||
U = onp.random.randn(10, 2)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"vmap got inconsistent sizes for array axes to be mapped:\n"
|
||||
r"arg 0 has shape \(10, 4\) and axis 0 is to be mapped" "\n"
|
||||
@ -1028,7 +1028,7 @@ class APITest(jtu.JaxTestCase):
|
||||
"arg 1 has an axis to be mapped of size 2"):
|
||||
api.vmap(h, in_axes=(0, 1))(X, U)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"vmap got inconsistent sizes for array axes to be mapped:\n"
|
||||
r"arg 0 has shape \(10, 4\) and axis 0 is to be mapped" "\n"
|
||||
@ -1039,32 +1039,38 @@ class APITest(jtu.JaxTestCase):
|
||||
"arg 1 has an axis to be mapped of size 2"):
|
||||
api.vmap(lambda x, y, z: None, in_axes=(0, 1, 0))(X, U, X)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"vmap got inconsistent sizes for array axes to be mapped:\n"
|
||||
"the tree of axis sizes is:\n"
|
||||
r"\(10, \[2, 2\]\)"):
|
||||
api.vmap(h, in_axes=(0, 1))(X, [U, U])
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "vmap got arg 0 of rank 0 but axis to be mapped 0"):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "vmap got arg 0 of rank 0 but axis to be mapped 0"):
|
||||
# The mapped inputs cannot be scalars
|
||||
api.vmap(lambda x: x)(1.)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, re.escape("vmap got arg 0 of rank 1 but axis to be mapped [1. 2.]")):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, re.escape("vmap got arg 0 of rank 1 but axis to be mapped [1. 2.]")):
|
||||
api.vmap(lambda x: x, in_axes=(np.array([1., 2.]),))(np.array([1., 2.]))
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "vmap must have at least one non-None in_axes"):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "vmap must have at least one non-None in_axes"):
|
||||
# If the output is mapped, there must be a non-None in_axes
|
||||
api.vmap(lambda x: x, in_axes=None)(np.array([1., 2.]))
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "vmap got arg 0 of rank 1 but axis to be mapped 1"):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "vmap got arg 0 of rank 1 but axis to be mapped 1"):
|
||||
api.vmap(lambda x: x, in_axes=1)(np.array([1., 2.]))
|
||||
|
||||
# Error is: TypeError: only integer scalar arrays can be converted to a scalar index
|
||||
with self.assertRaisesRegexp(ValueError, "axes specification must be a tree prefix of the corresponding value"):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "axes specification must be a tree prefix of the corresponding value"):
|
||||
api.vmap(lambda x: x, in_axes=0, out_axes=(2, 3))(np.array([1., 2.]))
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "vmap has mapped output but out_axes is None"):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "vmap has mapped output but out_axes is None"):
|
||||
# If the output is mapped, then there must be some out_axes specified
|
||||
api.vmap(lambda x: x, out_axes=None)(np.array([1., 2.]))
|
||||
|
||||
|
@ -49,6 +49,11 @@ all_dtypes = float_dtypes + int_dtypes + bool_types
|
||||
IndexSpec = collections.namedtuple("IndexTest", ["shape", "indexer"])
|
||||
|
||||
|
||||
suppress_deprecated_indexing_warnings = partial(
|
||||
jtu.ignore_warning, category=FutureWarning,
|
||||
message='Using a non-tuple sequence.*')
|
||||
|
||||
|
||||
def check_grads(f, args, order, atol=None, rtol=None, eps=None):
|
||||
# TODO(mattjj,dougalm): add higher-order check
|
||||
default_tol = 1e-6 if FLAGS.jax_enable_x64 else 1e-2
|
||||
@ -548,7 +553,7 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
def testAdvancedIntegerIndexing(self, shape, dtype, rng_factory, indexer):
|
||||
rng = rng_factory()
|
||||
args_maker = lambda: [rng(shape, dtype), indexer]
|
||||
fun = lambda x, idx: x[idx]
|
||||
fun = lambda x, idx: jnp.asarray(x)[idx]
|
||||
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@ -608,7 +613,7 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
rng = rng_factory()
|
||||
tol = 1e-2 if jnp.finfo(dtype).bits == 32 else None
|
||||
arg = rng(shape, dtype)
|
||||
fun = lambda x: x[indexer]**2
|
||||
fun = lambda x: jnp.asarray(x)[indexer]**2
|
||||
check_grads(fun, (arg,), 2, tol, tol, tol)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@ -629,7 +634,7 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
|
||||
def fun(x, indexer_with_dummies):
|
||||
idx = type(indexer)(util.subvals(indexer_with_dummies, substitutes))
|
||||
return x[idx]
|
||||
return jnp.asarray(x)[idx]
|
||||
|
||||
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
|
||||
|
||||
@ -803,6 +808,7 @@ def _broadcastable_shapes(shape):
|
||||
for x in f(list(reversed(shape))):
|
||||
yield list(reversed(x))
|
||||
|
||||
@suppress_deprecated_indexing_warnings()
|
||||
def _update_shape(shape, indexer):
|
||||
return onp.zeros(shape)[indexer].shape
|
||||
|
||||
@ -813,6 +819,7 @@ class UpdateOps(enum.Enum):
|
||||
MIN = 2
|
||||
MAX = 3
|
||||
|
||||
@suppress_deprecated_indexing_warnings()
|
||||
def onp_fn(op, indexer, x, y):
|
||||
x = x.copy()
|
||||
x[indexer] = {
|
||||
|
@ -448,6 +448,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
JAX_COMPOUND_OP_RECORDS)))
|
||||
def testOp(self, onp_op, jnp_op, rng_factory, shapes, dtypes, check_dtypes,
|
||||
tolerance, inexact):
|
||||
if onp_op is onp.float_power:
|
||||
onp_op = jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="invalid value.*")(onp_op)
|
||||
|
||||
rng = rng_factory()
|
||||
args_maker = self._GetArgsMaker(rng, shapes, dtypes, onp_arrays=False)
|
||||
tol = max(jtu.tolerance(dtype, tolerance) for dtype in dtypes)
|
||||
@ -580,12 +584,16 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def testReducer(self, onp_op, jnp_op, rng_factory, shape, dtype, out_dtype,
|
||||
axis, keepdims, inexact):
|
||||
rng = rng_factory()
|
||||
@jtu.ignore_warning(category=onp.ComplexWarning)
|
||||
@jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="mean of empty slice.*")
|
||||
def onp_fun(x):
|
||||
x_cast = x if dtype != jnp.bfloat16 else x.astype(onp.float32)
|
||||
t = out_dtype if out_dtype != jnp.bfloat16 else onp.float32
|
||||
return onp_op(x_cast, axis, dtype=t, keepdims=keepdims)
|
||||
onp_fun = _promote_like_jnp(onp_fun, inexact)
|
||||
jnp_fun = lambda x: jnp_op(x, axis, dtype=out_dtype, keepdims=keepdims)
|
||||
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
tol_spec = {onp.float16: 1e-2, onp.float32: 1e-3, onp.complex64: 1e-3,
|
||||
onp.float64: 1e-5, onp.complex128: 1e-5}
|
||||
@ -613,7 +621,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
rng = rng_factory()
|
||||
onp_fun = lambda x: onp_op(x, axis, keepdims=keepdims)
|
||||
onp_fun = _promote_like_jnp(onp_fun, inexact)
|
||||
onp_fun = jtu.ignore_warning(category=onp.ComplexWarning)(onp_fun)
|
||||
jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims)
|
||||
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
@ -640,6 +650,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def testNonzero(self, shape, dtype):
|
||||
rng = jtu.rand_some_zero()
|
||||
onp_fun = lambda x: onp.nonzero(x)
|
||||
onp_fun = jtu.ignore_warning(
|
||||
category=DeprecationWarning,
|
||||
message="Calling nonzero on 0d arrays.*")(onp_fun)
|
||||
jnp_fun = lambda x: jnp.nonzero(x)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
@ -1158,7 +1171,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def testCumSumProd(self, axis, shape, dtype, out_dtype, onp_op, jnp_op, rng_factory):
|
||||
rng = rng_factory()
|
||||
onp_fun = lambda arg: onp_op(arg, axis=axis, dtype=out_dtype)
|
||||
onp_fun = jtu.ignore_warning(category=onp.ComplexWarning)(onp_fun)
|
||||
jnp_fun = lambda arg: jnp_op(arg, axis=axis, dtype=out_dtype)
|
||||
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
|
||||
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
||||
@ -1278,10 +1293,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self.skipTest("Only run float64 testcase when float64 is enabled.")
|
||||
x1_rng = x1_rng_factory()
|
||||
x2_rng = x2_rng_factory()
|
||||
def onp_fun(x1, x2):
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
return onp.ldexp(x1, x2)
|
||||
onp_fun = lambda x1, x2: onp.ldexp(x1, x2)
|
||||
onp_fun = jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="overflow.*")(onp_fun)
|
||||
jnp_fun = lambda x1, x2: jnp.ldexp(x1, x2)
|
||||
args_maker = lambda: [x1_rng(x1_shape, x1_dtype),
|
||||
x2_rng(x2_shape, onp.int32)]
|
||||
@ -2237,6 +2251,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def testWhereOneArgument(self, shape, dtype):
|
||||
rng = jtu.rand_some_zero()
|
||||
onp_fun = lambda x: onp.where(x)
|
||||
onp_fun = jtu.ignore_warning(
|
||||
category=DeprecationWarning,
|
||||
message="Calling nonzero on 0d arrays.*")(onp_fun)
|
||||
jnp_fun = lambda x: jnp.where(x)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
@ -2461,6 +2478,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
"log", "expm1", "log1p")))
|
||||
def testMathSpecialFloatValues(self, op, dtype):
|
||||
onp_op = getattr(onp, op)
|
||||
onp_op = jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="invalid value.*")(onp_op)
|
||||
onp_op = jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="divide by zero.*")(onp_op)
|
||||
onp_op = jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="overflow.*")(onp_op)
|
||||
|
||||
jnp_op = getattr(jnp, op)
|
||||
dtype = onp.dtype(dtypes.canonicalize_dtype(dtype)).type
|
||||
for x in (onp.nan, -onp.inf, -100., -2., -1., 0., 1., 2., 100., onp.inf,
|
||||
@ -2559,26 +2583,24 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_dtype={}_rowvar={}_ddof={}_bias={}".format(
|
||||
shape, dtype, rowvar, ddof, bias),
|
||||
"shape": shape, "dtype": dtype, "rowvar": rowvar, "ddof": ddof,
|
||||
"bias": bias, "rng_factory": rng_factory}
|
||||
{"testcase_name": "_shape={}_dtype={}_rowvar={}".format(
|
||||
shape, dtype, rowvar),
|
||||
"shape": shape, "dtype": dtype, "rowvar": rowvar,
|
||||
"rng_factory": rng_factory}
|
||||
for shape in [(5,), (10, 5), (3, 10)]
|
||||
for dtype in number_dtypes
|
||||
for rowvar in [True, False]
|
||||
for bias in [True, False]
|
||||
for ddof in [None, 2, 3]
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def testCorrCoef(self, shape, dtype, rowvar, ddof, bias, rng_factory):
|
||||
def testCorrCoef(self, shape, dtype, rowvar, rng_factory):
|
||||
rng = rng_factory()
|
||||
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
|
||||
mat = onp.asarray([rng(shape, dtype)])
|
||||
onp_fun = partial(onp.corrcoef, rowvar=rowvar, ddof=ddof, bias=bias)
|
||||
jnp_fun = partial(jnp.corrcoef, rowvar=rowvar, ddof=ddof, bias=bias)
|
||||
onp_fun = partial(onp.corrcoef, rowvar=rowvar)
|
||||
jnp_fun = partial(jnp.corrcoef, rowvar=rowvar)
|
||||
if not onp.any(onp.isclose(onp.std(mat), 0.0)):
|
||||
self._CheckAgainstNumpy(
|
||||
onp_fun, jnp_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-2 if jtu.device_under_test() == "tpu" else None)
|
||||
self._CheckAgainstNumpy(
|
||||
onp_fun, jnp_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-2 if jtu.device_under_test() == "tpu" else None)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
|
@ -154,7 +154,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
A = lambda x: x
|
||||
b = jnp.zeros((2, 1))
|
||||
x0 = jnp.zeros((2,))
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "x0 and b must have matching shape"):
|
||||
jax.scipy.sparse.linalg.cg(A, b, x0)
|
||||
|
||||
|
@ -1901,6 +1901,8 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
jtu.tolerance(from_dtype, jtu.default_gradient_tolerance))
|
||||
args = (rng((2, 3), from_dtype),)
|
||||
convert_element_type = lambda x: lax.convert_element_type(x, to_dtype)
|
||||
convert_element_type = jtu.ignore_warning(category=onp.ComplexWarning)(
|
||||
convert_element_type)
|
||||
check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -3111,6 +3113,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
for padding in ["VALID", "SAME"]
|
||||
for rng_factory in [jtu.rand_small]))
|
||||
@jtu.skip_on_flag("jax_skip_slow_tests", True)
|
||||
@jtu.ignore_warning(message="Using reduced precision for gradient.*")
|
||||
def testSelectAndGatherAdd(self, dtype, padding, rng_factory):
|
||||
if jtu.device_under_test() == "tpu" and dtype == dtypes.bfloat16:
|
||||
raise SkipTest("bfloat16 _select_and_gather_add doesn't work on tpu")
|
||||
|
@ -62,6 +62,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (1000, 0, 0)]
|
||||
for dtype in float_types + complex_types
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
@jtu.skip_on_mac_linalg_bug()
|
||||
def testCholesky(self, shape, dtype, rng_factory):
|
||||
rng = rng_factory()
|
||||
_skip_if_unsupported_type(dtype)
|
||||
@ -121,11 +122,10 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
for nq in zip([2, 4, 6, 36], [(1, 2), (2, 2), (1, 2, 3), (3, 3, 1, 4)])
|
||||
for dtype in float_types
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
@jtu.skip_on_mac_linalg_bug()
|
||||
def testTensorsolve(self, m, nq, dtype, rng_factory):
|
||||
rng = rng_factory()
|
||||
_skip_if_unsupported_type(dtype)
|
||||
if m == 23:
|
||||
jtu.skip_on_mac_xla_bug()
|
||||
|
||||
# According to numpy docs the shapes are as follows:
|
||||
# Coefficient tensor (a), of shape b.shape + Q.
|
||||
@ -159,6 +159,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
for dtype in float_types + complex_types
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
@jtu.skip_on_devices("tpu")
|
||||
@jtu.skip_on_mac_linalg_bug()
|
||||
def testSlogdet(self, shape, dtype, rng_factory):
|
||||
rng = rng_factory()
|
||||
_skip_if_unsupported_type(dtype)
|
||||
@ -200,6 +201,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
# TODO(phawkins): enable when there is an eigendecomposition implementation
|
||||
# for GPU/TPU.
|
||||
@jtu.skip_on_devices("gpu", "tpu")
|
||||
@jtu.skip_on_mac_linalg_bug()
|
||||
def testEig(self, shape, dtype, rng_factory):
|
||||
rng = rng_factory()
|
||||
_skip_if_unsupported_type(dtype)
|
||||
@ -228,11 +230,10 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
# TODO: enable when there is an eigendecomposition implementation
|
||||
# for GPU/TPU.
|
||||
@jtu.skip_on_devices("gpu", "tpu")
|
||||
@jtu.skip_on_mac_linalg_bug()
|
||||
def testEigvals(self, shape, dtype, rng_factory):
|
||||
rng = rng_factory()
|
||||
_skip_if_unsupported_type(dtype)
|
||||
if shape == (50, 50) and dtype == onp.complex64:
|
||||
jtu.skip_on_mac_xla_bug()
|
||||
n = shape[-1]
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
a, = args_maker()
|
||||
@ -622,10 +623,9 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
for shape in [(1, 1), (4, 4), (200, 200), (7, 7, 7, 7)]
|
||||
for dtype in float_types
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
@jtu.skip_on_mac_linalg_bug()
|
||||
def testTensorinv(self, shape, dtype, rng_factory):
|
||||
_skip_if_unsupported_type(dtype)
|
||||
if shape[0] > 100:
|
||||
jtu.skip_on_mac_xla_bug()
|
||||
rng = rng_factory()
|
||||
|
||||
def tensor_maker():
|
||||
@ -677,11 +677,10 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (5, 5, 5)]
|
||||
for dtype in float_types
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
@jtu.skip_on_mac_linalg_bug()
|
||||
def testInv(self, shape, dtype, rng_factory):
|
||||
rng = rng_factory()
|
||||
_skip_if_unsupported_type(dtype)
|
||||
if shape == (200, 200) and dtype == onp.float32:
|
||||
jtu.skip_on_mac_xla_bug()
|
||||
if jtu.device_under_test() == "gpu" and shape == (200, 200):
|
||||
raise unittest.SkipTest("Test is flaky on GPU")
|
||||
|
||||
@ -708,11 +707,10 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
for dtype in float_types + complex_types
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
@jtu.skip_on_devices("tpu") # SVD is not implemented on the TPU backend
|
||||
@jtu.skip_on_mac_linalg_bug()
|
||||
def testPinv(self, shape, dtype, rng_factory):
|
||||
rng = rng_factory()
|
||||
_skip_if_unsupported_type(dtype)
|
||||
if shape == (7, 10000) and dtype in [onp.complex64, onp.float32]:
|
||||
jtu.skip_on_mac_xla_bug()
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(onp.linalg.pinv, np.linalg.pinv, args_maker,
|
||||
@ -824,6 +822,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
for shape in [(1, 1), (4, 5), (10, 5), (50, 50)]
|
||||
for dtype in float_types + complex_types
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
@jtu.skip_on_mac_linalg_bug()
|
||||
def testLu(self, shape, dtype, rng_factory):
|
||||
rng = rng_factory()
|
||||
_skip_if_unsupported_type(dtype)
|
||||
@ -884,11 +883,10 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
for n in [1, 4, 5, 200]
|
||||
for dtype in float_types + complex_types
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
@jtu.skip_on_mac_linalg_bug()
|
||||
def testLuFactor(self, n, dtype, rng_factory):
|
||||
rng = rng_factory()
|
||||
_skip_if_unsupported_type(dtype)
|
||||
if n == 200 and dtype == onp.complex64:
|
||||
jtu.skip_on_mac_xla_bug()
|
||||
args_maker = lambda: [rng((n, n), dtype)]
|
||||
|
||||
x, = args_maker()
|
||||
@ -1109,11 +1107,10 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
for n in [1, 4, 5, 20, 50, 100]
|
||||
for dtype in float_types + complex_types
|
||||
for rng_factory in [jtu.rand_small]))
|
||||
@jtu.skip_on_mac_linalg_bug()
|
||||
def testExpm(self, n, dtype, rng_factory):
|
||||
rng = rng_factory()
|
||||
_skip_if_unsupported_type(dtype)
|
||||
if n == 50 and dtype in [onp.complex64, onp.float32]:
|
||||
jtu.skip_on_mac_xla_bug()
|
||||
args_maker = lambda: [rng((n, n), dtype)]
|
||||
|
||||
osp_fun = lambda a: osp.linalg.expm(a)
|
||||
@ -1135,8 +1132,8 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
for n in [1, 4, 5, 20, 50, 100]
|
||||
for dtype in float_types + complex_types
|
||||
))
|
||||
@jtu.skip_on_mac_linalg_bug()
|
||||
def testIssue2131(self, n, dtype):
|
||||
jtu.skip_on_mac_xla_bug()
|
||||
args_maker_zeros = lambda: [onp.zeros((n, n), dtype)]
|
||||
osp_fun = lambda a: osp.linalg.expm(a)
|
||||
jsp_fun = lambda a: jsp.linalg.expm(a)
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import functools
|
||||
import itertools
|
||||
import unittest
|
||||
from unittest import SkipTest, skip
|
||||
@ -32,6 +32,9 @@ from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
ignore_soft_pmap_warning = functools.partial(
|
||||
jtu.ignore_warning, message="soft_pmap is an experimental.*")
|
||||
|
||||
class PapplyTest(jtu.JaxTestCase):
|
||||
|
||||
def testIdentity(self):
|
||||
@ -46,6 +49,7 @@ class PapplyTest(jtu.JaxTestCase):
|
||||
expected = onp.sin(onp.arange(3.))
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ignore_soft_pmap_warning()
|
||||
def testSum(self):
|
||||
pfun, axis_name = _papply(lambda x: np.sum(x, axis=0))
|
||||
|
||||
@ -59,6 +63,7 @@ class PapplyTest(jtu.JaxTestCase):
|
||||
expected = onp.sum(arg, axis=0)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ignore_soft_pmap_warning()
|
||||
def testMax(self):
|
||||
pfun, axis_name = _papply(lambda x: np.max(x, axis=0))
|
||||
|
||||
@ -72,6 +77,7 @@ class PapplyTest(jtu.JaxTestCase):
|
||||
expected = onp.max(arg, axis=0)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ignore_soft_pmap_warning()
|
||||
def testSelect(self):
|
||||
p = onp.arange(15).reshape((5, 3)) % 4 == 1
|
||||
f = onp.zeros((5, 3))
|
||||
@ -101,6 +107,7 @@ class PapplyTest(jtu.JaxTestCase):
|
||||
expected = fun(onp.arange(1., 5.))
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ignore_soft_pmap_warning()
|
||||
def testAdd(self):
|
||||
x = onp.array([[1, 2, 3], [4, 5, 6]])
|
||||
expected = x + x
|
||||
|
@ -61,6 +61,9 @@ def tearDownModule():
|
||||
os.environ["XLA_FLAGS"] = prev_xla_flags
|
||||
xla_bridge.get_backend.cache_clear()
|
||||
|
||||
ignore_soft_pmap_warning = partial(
|
||||
jtu.ignore_warning, message="soft_pmap is an experimental.*")
|
||||
|
||||
|
||||
class PmapTest(jtu.JaxTestCase):
|
||||
def _getMeshShape(self, device_mesh_shape):
|
||||
@ -434,7 +437,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
g = lambda: pmap(f, "i")(onp.arange(device_count))
|
||||
self.assertRaisesRegex(
|
||||
AssertionError,
|
||||
"Given `perm` does not represent a real permutation: \[1.*\]", g)
|
||||
"Given `perm` does not represent a real permutation: \\[1.*\\]", g)
|
||||
|
||||
@jtu.skip_on_devices("cpu", "gpu")
|
||||
def testPpermuteWithZipObject(self):
|
||||
@ -772,6 +775,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
expected = onp.swapaxes(x, 0, 2)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ignore_soft_pmap_warning()
|
||||
def testSoftPmapPsum(self):
|
||||
n = 4 * xla_bridge.device_count()
|
||||
def f(x):
|
||||
@ -780,6 +784,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
expected = onp.ones(n) / n
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ignore_soft_pmap_warning()
|
||||
def testSoftPmapAxisIndex(self):
|
||||
n = 4 * xla_bridge.device_count()
|
||||
def f(x):
|
||||
@ -788,6 +793,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
expected = 2 * onp.arange(n)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ignore_soft_pmap_warning()
|
||||
def testSoftPmapOfJit(self):
|
||||
n = 4 * xla_bridge.device_count()
|
||||
def f(x):
|
||||
@ -796,6 +802,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
expected = 3 * onp.arange(n)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ignore_soft_pmap_warning()
|
||||
def testSoftPmapNested(self):
|
||||
n = 4 * xla_bridge.device_count()
|
||||
|
||||
@ -809,6 +816,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
expected = onp.arange(n ** 2).reshape(n, n).T
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ignore_soft_pmap_warning()
|
||||
def testGradOfSoftPmap(self):
|
||||
n = 4 * xla_bridge.device_count()
|
||||
|
||||
@ -820,6 +828,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
expected = onp.repeat(onp.arange(n)[:, None], n, axis=1)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ignore_soft_pmap_warning()
|
||||
def testSoftPmapDevicePersistence(self):
|
||||
device_count = xla_bridge.device_count()
|
||||
shape = (2 * 2 * device_count, 2, 3)
|
||||
|
@ -175,7 +175,6 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
for p in [0.1, 0.5, 0.9]
|
||||
for dtype in [onp.float32, onp.float64]))
|
||||
def testBernoulli(self, p, dtype):
|
||||
jtu.skip_on_mac_xla_bug()
|
||||
key = random.PRNGKey(0)
|
||||
p = onp.array(p, dtype=dtype)
|
||||
rand = lambda key, p: random.bernoulli(key, p, (10000,))
|
||||
@ -194,7 +193,6 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
for sample_shape in [(10000,), (5000, 2)]
|
||||
for dtype in [onp.float32, onp.float64]))
|
||||
def testCategorical(self, p, axis, dtype, sample_shape):
|
||||
jtu.skip_on_mac_xla_bug()
|
||||
key = random.PRNGKey(0)
|
||||
p = onp.array(p, dtype=dtype)
|
||||
logits = onp.log(p) - 42 # test unnormalized
|
||||
@ -245,7 +243,6 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
|
||||
for dtype in [onp.float32, onp.float64]))
|
||||
def testCauchy(self, dtype):
|
||||
jtu.skip_on_mac_xla_bug()
|
||||
key = random.PRNGKey(0)
|
||||
rand = lambda key: random.cauchy(key, (10000,), dtype)
|
||||
crand = api.jit(rand)
|
||||
@ -264,7 +261,6 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
]
|
||||
for dtype in [onp.float32, onp.float64]))
|
||||
def testDirichlet(self, alpha, dtype):
|
||||
jtu.skip_on_mac_xla_bug()
|
||||
key = random.PRNGKey(0)
|
||||
rand = lambda key, alpha: random.dirichlet(key, alpha, (10000,), dtype)
|
||||
crand = api.jit(rand)
|
||||
@ -282,7 +278,6 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
|
||||
for dtype in [onp.float32, onp.float64]))
|
||||
def testExponential(self, dtype):
|
||||
jtu.skip_on_mac_xla_bug()
|
||||
key = random.PRNGKey(0)
|
||||
rand = lambda key: random.exponential(key, (10000,), dtype)
|
||||
crand = api.jit(rand)
|
||||
@ -299,7 +294,6 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
for a in [0.1, 1., 10.]
|
||||
for dtype in [onp.float32, onp.float64]))
|
||||
def testGamma(self, a, dtype):
|
||||
jtu.skip_on_mac_xla_bug()
|
||||
key = random.PRNGKey(0)
|
||||
rand = lambda key, a: random.gamma(key, a, (10000,), dtype)
|
||||
crand = api.jit(rand)
|
||||
@ -346,7 +340,6 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
|
||||
for dtype in [onp.float32, onp.float64]))
|
||||
def testGumbel(self, dtype):
|
||||
jtu.skip_on_mac_xla_bug()
|
||||
key = random.PRNGKey(0)
|
||||
rand = lambda key: random.gumbel(key, (10000,), dtype)
|
||||
crand = api.jit(rand)
|
||||
@ -428,8 +421,8 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
"dim": dim, "dtype": dtype}
|
||||
for dim in [1, 3, 5]
|
||||
for dtype in [onp.float32, onp.float64]))
|
||||
@jtu.skip_on_mac_linalg_bug()
|
||||
def testMultivariateNormal(self, dim, dtype):
|
||||
jtu.skip_on_mac_xla_bug()
|
||||
r = onp.random.RandomState(dim)
|
||||
mean = r.randn(dim)
|
||||
cov_factor = r.randn(dim, dim)
|
||||
@ -453,9 +446,9 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
# eigenvectors follow a standard normal distribution.
|
||||
self._CheckKolmogorovSmirnovCDF(whitened.ravel(), scipy.stats.norm().cdf)
|
||||
|
||||
@jtu.skip_on_mac_linalg_bug()
|
||||
def testMultivariateNormalCovariance(self):
|
||||
# test code based on https://github.com/google/jax/issues/1869
|
||||
jtu.skip_on_mac_xla_bug()
|
||||
N = 100000
|
||||
cov = np.array([[ 0.19, 0.00, -0.13, 0.00],
|
||||
[ 0.00, 0.29, 0.00, -0.23],
|
||||
|
@ -14,10 +14,12 @@
|
||||
|
||||
|
||||
import collections
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import jax.lib
|
||||
from jax import test_util as jtu
|
||||
from jax import tree_util
|
||||
|
||||
@ -157,11 +159,13 @@ class TreeTest(jtu.JaxTestCase):
|
||||
((3, {"foo": "bar"}), (4, 7), (5, [5, 6]))))
|
||||
|
||||
@parameterized.parameters(*TREES)
|
||||
@unittest.skipIf(jax.lib.version < (0, 1, 44), "Jaxlib too old")
|
||||
def testAllLeavesWithTrees(self, tree):
|
||||
leaves = tree_util.tree_leaves(tree)
|
||||
self.assertTrue(tree_util.all_leaves(leaves))
|
||||
self.assertFalse(tree_util.all_leaves([tree]))
|
||||
|
||||
@unittest.skipIf(jax.lib.version < (0, 1, 44), "Jaxlib too old")
|
||||
@parameterized.parameters(*LEAVES)
|
||||
def testAllLeavesWithLeaves(self, leaf):
|
||||
self.assertTrue(tree_util.all_leaves([leaf]))
|
||||
|
Loading…
x
Reference in New Issue
Block a user