From 7f4dc87a4c33a0b56631632b9b8b12c96db71514 Mon Sep 17 00:00:00 2001 From: fehiepsi Date: Thu, 1 Aug 2019 19:12:03 -0400 Subject: [PATCH] add multigammaln, entr --- jax/scipy/special.py | 21 +++++++++++++++++++++ tests/lax_scipy_test.py | 19 +++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/jax/scipy/special.py b/jax/scipy/special.py index c2e0ef692..5b242a05a 100644 --- a/jax/scipy/special.py +++ b/jax/scipy/special.py @@ -99,6 +99,27 @@ def xlog1py(x, y): return lax._safe_mul(x, lax.log1p(y)) +@_wraps(osp_special.entr) +def entr(x): + x, = _promote_args_like(osp_special.entr, x) + return lax.select(lax.lt(x, _constant_like(x, 0)), + lax.full_like(x, -onp.inf), + lax.neg(xlogy(x, x))) + + +@_wraps(osp_special.multigammaln) +def multigammaln(a, d): + a, = _promote_args_like(lambda a: osp_special.multigammaln(a, 1), a) + d = lax.convert_element_type(d, lax.dtype(a)) + constant = lax.mul(lax.mul(lax.mul(_constant_like(a, 0.25), d), + lax.sub(d, _constant_like(a, 1))), + lax.log(_constant_like(a, onp.pi))) + res = np.sum(gammaln(np.expand_dims(a, axis=-1) - + lax.div(np.arange(d), _constant_like(a, 2))), + axis=-1) + return res + constant + + # Normal distributions # Functions "ndtr" and "ndtri" are derived from calculations made in: diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index db5ad71dc..d329e57fe 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -68,6 +68,7 @@ JAX_SPECIAL_FUNCTION_RECORDS = [ op_record("log_ndtr", 1, float_dtypes, jtu.rand_default(), True), op_record("ndtri", 1, float_dtypes, jtu.rand_uniform(0.05, 0.95), True), op_record("ndtr", 1, float_dtypes, jtu.rand_default(), True), + op_record("entr", 1, float_dtypes, jtu.rand_default(), True), ] CombosWithReplacement = itertools.combinations_with_replacement @@ -121,6 +122,24 @@ class LaxBackedScipyTests(jtu.JaxTestCase): if test_autodiff: jtu.check_grads(lax_op, args, order=1, atol=1e-3, rtol=3e-3, eps=1e-3) + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_inshape={}_d={}".format( + jtu.format_shape_dtype_string(shape, dtype), d), + "rng": jtu.rand_positive(), "shape": shape, "dtype": dtype, "d": d} + for shape in all_shapes + for dtype in float_dtypes + for d in [1, 2, 5])) + def testMultigammaln(self, rng, shape, dtype, d): + def scipy_fun(a): + return osp_special.multigammaln(a, d) + + def lax_fun(a): + return lsp_special.multigammaln(a, d) + + args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.] + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True) + def testIssue980(self): x = onp.full((4,), -1e20, dtype=onp.float32) self.assertAllClose(onp.zeros((4,), dtype=onp.float32),