mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
add multigammaln, entr
This commit is contained in:
parent
c41677fac7
commit
7f4dc87a4c
@ -99,6 +99,27 @@ def xlog1py(x, y):
|
||||
return lax._safe_mul(x, lax.log1p(y))
|
||||
|
||||
|
||||
@_wraps(osp_special.entr)
|
||||
def entr(x):
|
||||
x, = _promote_args_like(osp_special.entr, x)
|
||||
return lax.select(lax.lt(x, _constant_like(x, 0)),
|
||||
lax.full_like(x, -onp.inf),
|
||||
lax.neg(xlogy(x, x)))
|
||||
|
||||
|
||||
@_wraps(osp_special.multigammaln)
|
||||
def multigammaln(a, d):
|
||||
a, = _promote_args_like(lambda a: osp_special.multigammaln(a, 1), a)
|
||||
d = lax.convert_element_type(d, lax.dtype(a))
|
||||
constant = lax.mul(lax.mul(lax.mul(_constant_like(a, 0.25), d),
|
||||
lax.sub(d, _constant_like(a, 1))),
|
||||
lax.log(_constant_like(a, onp.pi)))
|
||||
res = np.sum(gammaln(np.expand_dims(a, axis=-1) -
|
||||
lax.div(np.arange(d), _constant_like(a, 2))),
|
||||
axis=-1)
|
||||
return res + constant
|
||||
|
||||
|
||||
# Normal distributions
|
||||
|
||||
# Functions "ndtr" and "ndtri" are derived from calculations made in:
|
||||
|
@ -68,6 +68,7 @@ JAX_SPECIAL_FUNCTION_RECORDS = [
|
||||
op_record("log_ndtr", 1, float_dtypes, jtu.rand_default(), True),
|
||||
op_record("ndtri", 1, float_dtypes, jtu.rand_uniform(0.05, 0.95), True),
|
||||
op_record("ndtr", 1, float_dtypes, jtu.rand_default(), True),
|
||||
op_record("entr", 1, float_dtypes, jtu.rand_default(), True),
|
||||
]
|
||||
|
||||
CombosWithReplacement = itertools.combinations_with_replacement
|
||||
@ -121,6 +122,24 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
if test_autodiff:
|
||||
jtu.check_grads(lax_op, args, order=1, atol=1e-3, rtol=3e-3, eps=1e-3)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_inshape={}_d={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), d),
|
||||
"rng": jtu.rand_positive(), "shape": shape, "dtype": dtype, "d": d}
|
||||
for shape in all_shapes
|
||||
for dtype in float_dtypes
|
||||
for d in [1, 2, 5]))
|
||||
def testMultigammaln(self, rng, shape, dtype, d):
|
||||
def scipy_fun(a):
|
||||
return osp_special.multigammaln(a, d)
|
||||
|
||||
def lax_fun(a):
|
||||
return lsp_special.multigammaln(a, d)
|
||||
|
||||
args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
|
||||
def testIssue980(self):
|
||||
x = onp.full((4,), -1e20, dtype=onp.float32)
|
||||
self.assertAllClose(onp.zeros((4,), dtype=onp.float32),
|
||||
|
Loading…
x
Reference in New Issue
Block a user