diff --git a/tests/BUILD b/tests/BUILD index 91f0cdec8..bc61f1bfe 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -394,15 +394,13 @@ jax_test( jax_test( name = "lax_scipy_test", srcs = ["lax_scipy_test.py"], - backend_tags = { - "tpu": ["noasan"], # Test times out. - }, shard_count = { - "cpu": 40, - "gpu": 40, - "tpu": 40, + "cpu": 20, + "gpu": 20, + "tpu": 20, "iree": 10, }, + deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"), ) jax_test( @@ -419,6 +417,32 @@ jax_test( }, ) +jax_test( + name = "lax_scipy_special_functions_test", + srcs = ["lax_scipy_special_functions_test.py"], + shard_count = { + "cpu": 20, + "gpu": 20, + "tpu": 20, + "iree": 10, + }, + deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"), +) + +jax_test( + name = "lax_scipy_spectral_dac_test", + srcs = ["lax_scipy_spectral_dac_test.py"], + shard_count = { + "cpu": 40, + "gpu": 40, + "tpu": 40, + "iree": 40, + }, + deps = [ + "//jax:internal_test_util", + ] + py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"), +) + jax_test( name = "lax_test", srcs = ["lax_test.py"], diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py new file mode 100644 index 000000000..111e29012 --- /dev/null +++ b/tests/lax_scipy_special_functions_test.py @@ -0,0 +1,182 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import functools +import itertools + +from absl.testing import absltest +from absl.testing import parameterized + +import numpy as np +import scipy.special as osp_special + +import jax +from jax._src import test_util as jtu +from jax.scipy import special as lsp_special + +from jax.config import config +config.parse_flags_with_absl() +FLAGS = config.FLAGS + + +all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)] + +OpRecord = collections.namedtuple( + "OpRecord", + ["name", "nargs", "dtypes", "rng_factory", "test_autodiff", "nondiff_argnums", "test_name"]) + + +def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), test_name=None): + test_name = test_name or name + nondiff_argnums = tuple(sorted(set(nondiff_argnums))) + return OpRecord(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums, test_name) + + +float_dtypes = jtu.dtypes.floating +int_dtypes = jtu.dtypes.integer + +# TODO(phawkins): we should probably separate out the function domains used for +# autodiff tests from the function domains used for equivalence testing. For +# example, logit should closely match its scipy equivalent everywhere, but we +# don't expect numerical gradient tests to pass for inputs very close to 0. + +JAX_SPECIAL_FUNCTION_RECORDS = [ + op_record( + "betaln", 2, float_dtypes, jtu.rand_positive, False + ), + op_record( + "betainc", 3, float_dtypes, jtu.rand_positive, False + ), + op_record( + "digamma", 1, float_dtypes, jtu.rand_positive, True + ), + op_record( + "gammainc", 2, float_dtypes, jtu.rand_positive, True + ), + op_record( + "gammaincc", 2, float_dtypes, jtu.rand_positive, True + ), + op_record( + "erf", 1, float_dtypes, jtu.rand_small_positive, True + ), + op_record( + "erfc", 1, float_dtypes, jtu.rand_small_positive, True + ), + op_record( + "erfinv", 1, float_dtypes, jtu.rand_small_positive, True + ), + op_record( + "expit", 1, float_dtypes, jtu.rand_small_positive, True + ), + # TODO: gammaln has slightly high error. + op_record( + "gammaln", 1, float_dtypes, jtu.rand_positive, False + ), + op_record( + "i0", 1, float_dtypes, jtu.rand_default, True + ), + op_record( + "i0e", 1, float_dtypes, jtu.rand_default, True + ), + op_record( + "i1", 1, float_dtypes, jtu.rand_default, True + ), + op_record( + "i1e", 1, float_dtypes, jtu.rand_default, True + ), + op_record( + "logit", 1, float_dtypes, + functools.partial(jtu.rand_uniform, low=0.05, high=0.95), True), + op_record( + "log_ndtr", 1, float_dtypes, jtu.rand_default, True + ), + op_record( + "ndtri", 1, float_dtypes, + functools.partial(jtu.rand_uniform, low=0.05, high=0.95), True, + ), + op_record( + "ndtr", 1, float_dtypes, jtu.rand_default, True + ), + # TODO(phawkins): gradient of entr yields NaNs. + op_record( + "entr", 1, float_dtypes, jtu.rand_default, False + ), + op_record( + "polygamma", 2, (int_dtypes, float_dtypes), + jtu.rand_positive, True, (0,)), + op_record( + "xlogy", 2, float_dtypes, jtu.rand_positive, True + ), + op_record( + "xlog1py", 2, float_dtypes, jtu.rand_default, True + ), + # TODO: enable gradient test for zeta by restricting the domain of + # of inputs to some reasonable intervals + op_record("zeta", 2, float_dtypes, jtu.rand_positive, False), + # TODO: float64 produces aborts on gpu, potentially related to use of jnp.piecewise + op_record( + "expi", 1, [np.float32], + functools.partial(jtu.rand_not_small, offset=0.1), True), + op_record("exp1", 1, [np.float32], jtu.rand_positive, True), + op_record( + "expn", 2, (int_dtypes, [np.float32]), jtu.rand_positive, True, (0,)), +] + + +class LaxScipySpcialFunctionsTest(jtu.JaxTestCase): + + def _GetArgsMaker(self, rng, shapes, dtypes): + return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)] + + @parameterized.parameters(itertools.chain.from_iterable( + jtu.sample_product_testcases( + [dict(op=rec.name, rng_factory=rec.rng_factory, + test_autodiff=rec.test_autodiff, + nondiff_argnums=rec.nondiff_argnums)], + shapes=itertools.combinations_with_replacement(all_shapes, rec.nargs), + dtypes=(itertools.combinations_with_replacement(rec.dtypes, rec.nargs) + if isinstance(rec.dtypes, list) else itertools.product(*rec.dtypes)), + ) + for rec in JAX_SPECIAL_FUNCTION_RECORDS + )) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + @jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion + def testScipySpecialFun(self, op, rng_factory, shapes, dtypes, + test_autodiff, nondiff_argnums): + scipy_op = getattr(osp_special, op) + lax_op = getattr(lsp_special, op) + rng = rng_factory(self.rng()) + args_maker = self._GetArgsMaker(rng, shapes, dtypes) + args = args_maker() + self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3, + check_dtypes=False) + self._CompileAndCheck(lax_op, args_maker, rtol=1e-4) + + if test_autodiff: + def partial_lax_op(*vals): + list_args = list(vals) + for i in nondiff_argnums: + list_args.insert(i, args[i]) + return lax_op(*list_args) + + assert list(nondiff_argnums) == sorted(set(nondiff_argnums)) + diff_args = [x for i, x in enumerate(args) if i not in nondiff_argnums] + jtu.check_grads(partial_lax_op, diff_args, order=1, + atol=jtu.if_device_under_test("tpu", .1, 1e-3), + rtol=.1, eps=1e-3) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_scipy_spectral_dac_test.py b/tests/lax_scipy_spectral_dac_test.py new file mode 100644 index 000000000..683e81a34 --- /dev/null +++ b/tests/lax_scipy_spectral_dac_test.py @@ -0,0 +1,64 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from jax import lax +from jax import numpy as jnp +from jax._src import test_util as jtu +from jax._src.lax import eigh as lax_eigh + +from absl.testing import absltest + +from jax.config import config +config.parse_flags_with_absl() +FLAGS = config.FLAGS + + +linear_sizes = [16, 97, 128] + + +class LaxScipySpectralDacTest(jtu.JaxTestCase): + + @jtu.sample_product( + linear_size=linear_sizes, + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + termination_size=[1, 19], + ) + def test_spectral_dac_eigh(self, linear_size, dtype, termination_size): + if jtu.device_under_test() != "tpu" and termination_size != 1: + raise unittest.SkipTest( + "Termination sizes greater than 1 only work on TPU") + + rng = self.rng() + H = rng.randn(linear_size, linear_size) + H = jnp.array(0.5 * (H + H.conj().T)).astype(dtype) + if jnp.dtype(dtype).name in ("bfloat16", "float16"): + self.assertRaises( + NotImplementedError, lax_eigh.eigh, H) + return + evs, V = lax_eigh.eigh(H, termination_size=termination_size) + ev_exp, _ = jnp.linalg.eigh(H) + HV = jnp.dot(H, V, precision=lax.Precision.HIGHEST) + vV = evs.astype(V.dtype)[None, :] * V + eps = jnp.finfo(H.dtype).eps + atol = jnp.linalg.norm(H) * eps + self.assertAllClose(ev_exp, jnp.sort(evs), atol=20 * atol) + self.assertAllClose( + HV, vV, atol=atol * (140 if jnp.issubdtype(dtype, jnp.complexfloating) + else 40)) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index 748a3cdff..ca431eb9b 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -13,14 +13,12 @@ # limitations under the License. -import collections import functools from functools import partial import itertools import unittest from absl.testing import absltest -from absl.testing import parameterized import numpy as np import scipy.special as osp_special @@ -34,7 +32,6 @@ from jax.tree_util import tree_map from jax._src import test_util as jtu from jax.scipy import special as lsp_special from jax.scipy import cluster as lsp_cluster -from jax._src.lax import eigh as lax_eigh from jax.config import config config.parse_flags_with_absl() @@ -61,8 +58,6 @@ sides = ["right", "left"] methods = ["qdwh", "svd"] seeds = [1, 10] -linear_sizes = [16, 128, 256] - def _initialize_polar_test(rng, shape, n_zero_svs, degeneracy, geometric_spectrum, max_sv, nonzero_condition_number, dtype): @@ -94,67 +89,10 @@ def _initialize_polar_test(rng, shape, n_zero_svs, degeneracy, geometric_spectru spectrum = jnp.array(svs).astype(dtype) return result, spectrum -OpRecord = collections.namedtuple( - "OpRecord", - ["name", "nargs", "dtypes", "rng_factory", "test_autodiff", "nondiff_argnums", "test_name"]) - - -def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), test_name=None): - test_name = test_name or name - nondiff_argnums = tuple(sorted(set(nondiff_argnums))) - return OpRecord(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums, test_name) - -# TODO(phawkins): we should probably separate out the function domains used for -# autodiff tests from the function domains used for equivalence testing. For -# example, logit should closely match its scipy equivalent everywhere, but we -# don't expect numerical gradient tests to pass for inputs very close to 0. - -JAX_SPECIAL_FUNCTION_RECORDS = [ - op_record("betaln", 2, float_dtypes, jtu.rand_positive, False), - op_record("betainc", 3, float_dtypes, jtu.rand_positive, False), - op_record("digamma", 1, float_dtypes, jtu.rand_positive, True), - op_record("gammainc", 2, float_dtypes, jtu.rand_positive, True), - op_record("gammaincc", 2, float_dtypes, jtu.rand_positive, True), - op_record("erf", 1, float_dtypes, jtu.rand_small_positive, True), - op_record("erfc", 1, float_dtypes, jtu.rand_small_positive, True), - op_record("erfinv", 1, float_dtypes, jtu.rand_small_positive, True), - op_record("expit", 1, float_dtypes, jtu.rand_small_positive, True), - # TODO: gammaln has slightly high error. - op_record("gammaln", 1, float_dtypes, jtu.rand_positive, False), - op_record("i0", 1, float_dtypes, jtu.rand_default, True), - op_record("i0e", 1, float_dtypes, jtu.rand_default, True), - op_record("i1", 1, float_dtypes, jtu.rand_default, True), - op_record("i1e", 1, float_dtypes, jtu.rand_default, True), - op_record("logit", 1, float_dtypes, partial(jtu.rand_uniform, low=0.05, - high=0.95), True), - op_record("log_ndtr", 1, float_dtypes, jtu.rand_default, True), - op_record("ndtri", 1, float_dtypes, partial(jtu.rand_uniform, low=0.05, - high=0.95), - True), - op_record("ndtr", 1, float_dtypes, jtu.rand_default, True), - # TODO(phawkins): gradient of entr yields NaNs. - op_record("entr", 1, float_dtypes, jtu.rand_default, False), - op_record("polygamma", 2, (int_dtypes, float_dtypes), jtu.rand_positive, True, (0,)), - op_record("xlogy", 2, float_dtypes, jtu.rand_positive, True), - op_record("xlog1py", 2, float_dtypes, jtu.rand_default, True), - # TODO: enable gradient test for zeta by restricting the domain of - # of inputs to some reasonable intervals - op_record("zeta", 2, float_dtypes, jtu.rand_positive, False), - # TODO: float64 produces aborts on gpu, potentially related to use of jnp.piecewise - op_record("expi", 1, [np.float32], partial(jtu.rand_not_small, offset=0.1), - True), - op_record("exp1", 1, [np.float32], jtu.rand_positive, True), - op_record("expn", 2, (int_dtypes, [np.float32]), jtu.rand_positive, True, (0,)), -] - class LaxBackedScipyTests(jtu.JaxTestCase): """Tests for LAX-backed Scipy implementation.""" - def _GetArgsMaker(self, rng, shapes, dtypes): - return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)] - - @jtu.sample_product( [dict(shapes=shapes, axis=axis, use_b=use_b) for shape_group in compatible_shapes @@ -232,43 +170,6 @@ class LaxBackedScipyTests(jtu.JaxTestCase): result = lsp_special.logsumexp(1.0, b=1.0) self.assertEqual(result, 1.0) - @parameterized.parameters(itertools.chain.from_iterable( - jtu.sample_product_testcases( - [dict(op=rec.name, rng_factory=rec.rng_factory, - test_autodiff=rec.test_autodiff, - nondiff_argnums=rec.nondiff_argnums)], - shapes=itertools.combinations_with_replacement(all_shapes, rec.nargs), - dtypes=(itertools.combinations_with_replacement(rec.dtypes, rec.nargs) - if isinstance(rec.dtypes, list) else itertools.product(*rec.dtypes)), - ) - for rec in JAX_SPECIAL_FUNCTION_RECORDS - )) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - @jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion - def testScipySpecialFun(self, op, rng_factory, shapes, dtypes, - test_autodiff, nondiff_argnums): - scipy_op = getattr(osp_special, op) - lax_op = getattr(lsp_special, op) - rng = rng_factory(self.rng()) - args_maker = self._GetArgsMaker(rng, shapes, dtypes) - args = args_maker() - self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3, - check_dtypes=False) - self._CompileAndCheck(lax_op, args_maker, rtol=1e-4) - - if test_autodiff: - def partial_lax_op(*vals): - list_args = list(vals) - for i in nondiff_argnums: - list_args.insert(i, args[i]) - return lax_op(*list_args) - - assert list(nondiff_argnums) == sorted(set(nondiff_argnums)) - diff_args = [x for i, x in enumerate(args) if i not in nondiff_argnums] - jtu.check_grads(partial_lax_op, diff_args, order=1, - atol=jtu.if_device_under_test("tpu", .1, 1e-3), - rtol=.1, eps=1e-3) - @jtu.sample_product( shape=all_shapes, dtype=float_dtypes, @@ -581,34 +482,6 @@ class LaxBackedScipyTests(jtu.JaxTestCase): self.assertAllClose( matrix, recon, atol=tol * jnp.linalg.norm(matrix)) - @jtu.sample_product( - linear_size=linear_sizes, - dtype=jtu.dtypes.floating + jtu.dtypes.complex, - termination_size=[1, 19], - ) - def test_spectral_dac_eigh(self, linear_size, dtype, termination_size): - if jtu.device_under_test() != "tpu" and termination_size != 1: - raise unittest.SkipTest( - "Termination sizes greater than 1 only work on TPU") - - rng = self.rng() - H = rng.randn(linear_size, linear_size) - H = jnp.array(0.5 * (H + H.conj().T)).astype(dtype) - if jnp.dtype(dtype).name in ("bfloat16", "float16"): - self.assertRaises( - NotImplementedError, lax_eigh.eigh, H) - return - evs, V = lax_eigh.eigh(H, termination_size=termination_size) - ev_exp, eV_exp = jnp.linalg.eigh(H) - HV = jnp.dot(H, V, precision=lax.Precision.HIGHEST) - vV = evs.astype(V.dtype)[None, :] * V - eps = jnp.finfo(H.dtype).eps - atol = jnp.linalg.norm(H) * eps - self.assertAllClose(ev_exp, jnp.sort(evs), atol=20 * atol) - self.assertAllClose( - HV, vV, atol=atol * (140 if jnp.issubdtype(dtype, jnp.complexfloating) - else 30)) - @jtu.sample_product( n_obs=[1, 3, 5], n_codes=[1, 2, 4], @@ -645,6 +518,5 @@ class LaxBackedScipyTests(jtu.JaxTestCase): self.assertArraysEqual(actual, nan_array, check_dtypes=False) - if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())