Expose logsumexp as scipy.special.logsumexp.

scipy.misc.logsumexp is deprecated and appears slated to be removed entirely in scipy 1.3.
This commit is contained in:
Peter Hawkins 2019-02-24 11:49:15 -05:00
parent daf3e3ffc7
commit 95483c76e9
5 changed files with 23 additions and 17 deletions

View File

@ -27,7 +27,7 @@ import numpy.random as npr
from jax.api import jit, grad
from jax.config import config
from jax.scipy.misc import logsumexp
from jax.scipy.special import logsumexp
import jax.numpy as np
from examples import datasets

View File

@ -31,7 +31,7 @@ from six.moves import reduce
from jax import lax
from jax import random
from jax.scipy.misc import logsumexp
from jax.scipy.special import logsumexp
import jax.numpy as np

View File

@ -20,18 +20,9 @@ import numpy as onp
import scipy.misc as osp_misc
from .. import lax
from ..scipy import special
from ..numpy.lax_numpy import _wraps, _reduction_dims, _constant_like
@_wraps(osp_misc.logsumexp)
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
if b is not None or return_sign:
raise NotImplementedError("Only implemented for b=None, return_sign=False")
dims = _reduction_dims(a, axis)
shape = lax.subvals(onp.shape(a), zip(dims, (1,) * len(dims)))
dimadd = lambda x: lax.reshape(x, shape)
amax = lax.reduce(a, _constant_like(a, -onp.inf), lax.max, dims)
amax_singletons = dimadd(amax)
out = lax.add(lax.log(lax.reduce(lax.exp(lax.sub(a, amax_singletons)),
_constant_like(a, 0), lax.add, dims)), amax)
return dimadd(out) if keepdims else out
if hasattr(osp_misc, 'logsumexp'):
logsumexp = special.logsumexp

View File

@ -16,10 +16,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as onp
import scipy.special as osp_special
from .. import lax
from ..numpy.lax_numpy import _wraps, asarray
from ..numpy.lax_numpy import _wraps, asarray, _reduction_dims, _constant_like
# need to create new functions because _wraps sets the __name__ attribute
@ -41,3 +42,17 @@ def expit(x):
x = asarray(x)
one = lax._const(x, 1)
return lax.div(one, lax.add(one, lax.exp(lax.neg(x))))
@_wraps(osp_special.logsumexp)
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
if b is not None or return_sign:
raise NotImplementedError("Only implemented for b=None, return_sign=False")
dims = _reduction_dims(a, axis)
shape = lax.subvals(onp.shape(a), zip(dims, (1,) * len(dims)))
dimadd = lambda x: lax.reshape(x, shape)
amax = lax.reduce(a, _constant_like(a, -onp.inf), lax.max, dims)
amax_singletons = dimadd(amax)
out = lax.add(lax.log(lax.reduce(lax.exp(lax.sub(a, amax_singletons)),
_constant_like(a, 0), lax.add, dims)), amax)
return dimadd(out) if keepdims else out

View File

@ -87,10 +87,10 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
def testLogSumExp(self, rng, shape, dtype, axis, keepdims):
# TODO(mattjj): test autodiff
def scipy_fun(array_to_reduce):
return osp_misc.logsumexp(array_to_reduce, axis, keepdims=keepdims)
return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims)
def lax_fun(array_to_reduce):
return lsp_misc.logsumexp(array_to_reduce, axis, keepdims=keepdims)
return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)