# Copyright 2018 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import collections import functools import itertools from absl.testing import absltest from absl.testing import parameterized import numpy as np import scipy import scipy.special as osp_special import jax import jax.numpy as jnp from jax._src import test_util as jtu from jax.scipy import special as lsp_special jax.config.parse_flags_with_absl() all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)] OpRecord = collections.namedtuple( "OpRecord", ["name", "nargs", "dtypes", "rng_factory", "test_autodiff", "nondiff_argnums", "test_name"]) def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), test_name=None): test_name = test_name or name nondiff_argnums = tuple(sorted(set(nondiff_argnums))) return OpRecord(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums, test_name) float_dtypes = jtu.dtypes.floating int_dtypes = jtu.dtypes.integer # TODO(phawkins): we should probably separate out the function domains used for # autodiff tests from the function domains used for equivalence testing. For # example, logit should closely match its scipy equivalent everywhere, but we # don't expect numerical gradient tests to pass for inputs very close to 0. JAX_SPECIAL_FUNCTION_RECORDS = [ op_record( "beta", 2, float_dtypes, jtu.rand_default, False ), op_record( "betaln", 2, float_dtypes, jtu.rand_default, False ), op_record( "betainc", 3, float_dtypes, jtu.rand_positive, False ), op_record( "gamma", 1, float_dtypes, jtu.rand_default, True ), op_record( "digamma", 1, float_dtypes, jtu.rand_positive, True ), op_record( "gammainc", 2, float_dtypes, jtu.rand_positive, True ), op_record( "gammaincc", 2, float_dtypes, jtu.rand_positive, True ), op_record( "gammasgn", 1, float_dtypes, jtu.rand_default, True ), op_record( "erf", 1, float_dtypes, jtu.rand_small_positive, True ), op_record( "erfc", 1, float_dtypes, jtu.rand_small_positive, True ), op_record( "erfinv", 1, float_dtypes, jtu.rand_small_positive, True ), op_record( "expit", 1, float_dtypes, jtu.rand_small_positive, True ), # TODO: gammaln has slightly high error. op_record( "gammaln", 1, float_dtypes, jtu.rand_positive, False ), op_record( "factorial", 1, float_dtypes, jtu.rand_default, True ), op_record( "fresnel", 1, float_dtypes, functools.partial(jtu.rand_default, scale=30), True ), op_record( "i0", 1, float_dtypes, jtu.rand_default, True ), op_record( # Note: values near zero can fail numeric gradient tests. "i0e", 1, float_dtypes, functools.partial(jtu.rand_not_small, offset=0.1), True ), op_record( "i1", 1, float_dtypes, jtu.rand_default, True ), op_record( "i1e", 1, float_dtypes, jtu.rand_default, True ), op_record( "logit", 1, float_dtypes, functools.partial(jtu.rand_uniform, low=0.05, high=0.95), True), op_record( "log_ndtr", 1, float_dtypes, jtu.rand_default, True ), op_record( "ndtri", 1, float_dtypes, functools.partial(jtu.rand_uniform, low=0.0, high=1.0), True, ), op_record( "ndtr", 1, float_dtypes, jtu.rand_default, True ), # TODO(phawkins): gradient of entr yields NaNs. op_record( "entr", 1, float_dtypes, jtu.rand_default, False ), op_record( "polygamma", 2, (int_dtypes, float_dtypes), jtu.rand_positive, True, (0,)), op_record( "xlogy", 2, float_dtypes, jtu.rand_positive, True ), op_record( "xlog1py", 2, float_dtypes, jtu.rand_default, True ), op_record("zeta", 2, float_dtypes, jtu.rand_positive, True), # TODO: float64 produces aborts on gpu, potentially related to use of jnp.piecewise op_record( "expi", 1, [np.float32], functools.partial(jtu.rand_not_small, offset=0.1), True), op_record("exp1", 1, [np.float32], jtu.rand_positive, True), op_record( "expn", 2, (int_dtypes, [np.float32]), jtu.rand_positive, True, (0,)), op_record("kl_div", 2, float_dtypes, jtu.rand_positive, True), op_record( "rel_entr", 2, float_dtypes, jtu.rand_positive, True, ), op_record("poch", 2, float_dtypes, jtu.rand_positive, True), op_record( "hyp1f1", 3, float_dtypes, functools.partial(jtu.rand_uniform, low=0.5, high=30), True ), op_record("log_softmax", 1, float_dtypes, jtu.rand_default, True), op_record("softmax", 1, float_dtypes, jtu.rand_default, True), ] def _pretty_special_fun_name(case): shapes_str = "_".join("x".join(map(str, shape)) if shape else "s" for shape in case["shapes"]) dtypes_str = "_".join(np.dtype(d).name for d in case["dtypes"]) name = f"_{case['op']}_{shapes_str}_{dtypes_str}" return dict(**case, testcase_name=name) class LaxScipySpcialFunctionsTest(jtu.JaxTestCase): def _GetArgsMaker(self, rng, shapes, dtypes): return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)] @parameterized.named_parameters(itertools.chain.from_iterable( map(_pretty_special_fun_name, jtu.sample_product_testcases( [dict(op=rec.name, rng_factory=rec.rng_factory, test_autodiff=rec.test_autodiff, nondiff_argnums=rec.nondiff_argnums)], shapes=itertools.combinations_with_replacement(all_shapes, rec.nargs), dtypes=(itertools.combinations_with_replacement(rec.dtypes, rec.nargs) if isinstance(rec.dtypes, list) else itertools.product(*rec.dtypes)), )) for rec in JAX_SPECIAL_FUNCTION_RECORDS )) @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. @jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion def testScipySpecialFun(self, op, rng_factory, shapes, dtypes, test_autodiff, nondiff_argnums): scipy_op = getattr(osp_special, op) lax_op = getattr(lsp_special, op) rng = rng_factory(self.rng()) args_maker = self._GetArgsMaker(rng, shapes, dtypes) args = args_maker() self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3, check_dtypes=False) self._CompileAndCheck(lax_op, args_maker, rtol=1e-4) if test_autodiff: def partial_lax_op(*vals): list_args = list(vals) for i in nondiff_argnums: list_args.insert(i, args[i]) return lax_op(*list_args) assert list(nondiff_argnums) == sorted(set(nondiff_argnums)) diff_args = [x for i, x in enumerate(args) if i not in nondiff_argnums] jtu.check_grads(partial_lax_op, diff_args, order=1, atol=.1 if jtu.test_device_matches(["tpu"]) else 1e-3, rtol=.1, eps=1e-3) @jtu.sample_product( n=[0, 1, 2, 3, 10, 50] ) def testScipySpecialFunBernoulli(self, n): dtype = jnp.zeros(0).dtype # default float dtype. scipy_op = lambda: osp_special.bernoulli(n).astype(dtype) lax_op = functools.partial(lsp_special.bernoulli, n) args_maker = lambda: [] self._CheckAgainstNumpy(scipy_op, lax_op, args_maker, atol=0, rtol=1E-5) self._CompileAndCheck(lax_op, args_maker, atol=0, rtol=1E-5) def testGammaSign(self): dtype = jnp.zeros(0).dtype # default float dtype. typ = dtype.type testcases = [ (np.arange(-10, 0).astype(dtype), np.array([np.nan] * 10, dtype=dtype)), (np.nextafter(np.arange(-5, 0).astype(dtype), typ(-np.inf)), np.array([1, -1, 1, -1, 1], dtype=dtype)), (np.nextafter(np.arange(-5, 0).astype(dtype), typ(np.inf)), np.array([-1, 1, -1, 1, -1], dtype=dtype)), (np.arange(0, 10).astype(dtype), np.ones((10,), dtype)), (np.nextafter(np.arange(0, 10).astype(dtype), typ(np.inf)), np.ones((10,), dtype)), (np.nextafter(np.arange(1, 10).astype(dtype), typ(-np.inf)), np.ones((9,), dtype)), (np.array([-np.inf, -0.0, 0.0, np.inf, np.nan]), np.array([np.nan, -1.0, 1.0, 1.0, np.nan])) ] for inp, out in testcases: self.assertArraysEqual(out, lsp_special.gammasgn(inp)) self.assertArraysEqual(out, jnp.sign(lsp_special.gamma(inp))) if jtu.parse_version(scipy.__version__) >= (1, 15): self.assertArraysEqual(out, osp_special.gammasgn(inp)) self.assertAllClose(osp_special.gammasgn(inp), lsp_special.gammasgn(inp)) def testNdtriExtremeValues(self): # Testing at the extreme values (bounds (0. and 1.) and outside the bounds). dtype = jnp.zeros(0).dtype # default float dtype. args_maker = lambda: [np.arange(-10, 10).astype(dtype)] rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 self._CheckAgainstNumpy(osp_special.ndtri, lsp_special.ndtri, args_maker, rtol=rtol) self._CompileAndCheck(lsp_special.ndtri, args_maker, rtol=rtol) def testRelEntrExtremeValues(self): # Testing at the extreme values (bounds (0. and 1.) and outside the bounds). dtype = jnp.zeros(0).dtype # default float dtype. args_maker = lambda: [np.array([-2, -2, -2, -1, -1, -1, 0, 0, 0]).astype(dtype), np.array([-1, 0, 1, -1, 0, 1, -1, 0, 1]).astype(dtype)] rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 self._CheckAgainstNumpy(osp_special.rel_entr, lsp_special.rel_entr, args_maker, rtol=rtol) self._CompileAndCheck(lsp_special.rel_entr, args_maker, rtol=rtol) def testBetaParameterDeprecation(self): with self.assertNoWarnings(): lsp_special.beta(1, 1) lsp_special.beta(1, b=1) lsp_special.beta(a=1, b=1) with self.assertRaises(TypeError): lsp_special.beta(x=1, y=1) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())