add multigammaln, entr

This commit is contained in:
fehiepsi 2019-08-01 19:12:03 -04:00
parent c41677fac7
commit 7f4dc87a4c
2 changed files with 40 additions and 0 deletions

View File

@ -99,6 +99,27 @@ def xlog1py(x, y):
return lax._safe_mul(x, lax.log1p(y))
@_wraps(osp_special.entr)
def entr(x):
x, = _promote_args_like(osp_special.entr, x)
return lax.select(lax.lt(x, _constant_like(x, 0)),
lax.full_like(x, -onp.inf),
lax.neg(xlogy(x, x)))
@_wraps(osp_special.multigammaln)
def multigammaln(a, d):
a, = _promote_args_like(lambda a: osp_special.multigammaln(a, 1), a)
d = lax.convert_element_type(d, lax.dtype(a))
constant = lax.mul(lax.mul(lax.mul(_constant_like(a, 0.25), d),
lax.sub(d, _constant_like(a, 1))),
lax.log(_constant_like(a, onp.pi)))
res = np.sum(gammaln(np.expand_dims(a, axis=-1) -
lax.div(np.arange(d), _constant_like(a, 2))),
axis=-1)
return res + constant
# Normal distributions
# Functions "ndtr" and "ndtri" are derived from calculations made in:

View File

@ -68,6 +68,7 @@ JAX_SPECIAL_FUNCTION_RECORDS = [
op_record("log_ndtr", 1, float_dtypes, jtu.rand_default(), True),
op_record("ndtri", 1, float_dtypes, jtu.rand_uniform(0.05, 0.95), True),
op_record("ndtr", 1, float_dtypes, jtu.rand_default(), True),
op_record("entr", 1, float_dtypes, jtu.rand_default(), True),
]
CombosWithReplacement = itertools.combinations_with_replacement
@ -121,6 +122,24 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
if test_autodiff:
jtu.check_grads(lax_op, args, order=1, atol=1e-3, rtol=3e-3, eps=1e-3)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_d={}".format(
jtu.format_shape_dtype_string(shape, dtype), d),
"rng": jtu.rand_positive(), "shape": shape, "dtype": dtype, "d": d}
for shape in all_shapes
for dtype in float_dtypes
for d in [1, 2, 5]))
def testMultigammaln(self, rng, shape, dtype, d):
def scipy_fun(a):
return osp_special.multigammaln(a, d)
def lax_fun(a):
return lsp_special.multigammaln(a, d)
args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
def testIssue980(self):
x = onp.full((4,), -1e20, dtype=onp.float32)
self.assertAllClose(onp.zeros((4,), dtype=onp.float32),