Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
Expose logsumexp as scipy.special.logsumexp.
scipy.misc.logsumexp is deprecated and appears slated to be removed entirely in scipy 1.3.
parent daf3e3ffc7
commit 95483c76e9
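For orientation, a minimal usage sketch of the new import path (not part of the commit; the array values are illustrative), checked against the SciPy function it mirrors:

import numpy as onp
import scipy.special as osp_special

from jax.scipy.special import logsumexp  # was: from jax.scipy.misc import logsumexp

x = onp.array([[1.0, 2.0, 3.0],
               [0.5, 0.5, 0.5]])

# Both calls reduce over axis 1 and keep the reduced dimension.
print(logsumexp(x, axis=1, keepdims=True))               # JAX implementation
print(osp_special.logsumexp(x, axis=1, keepdims=True))   # SciPy reference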
@@ -27,7 +27,7 @@ import numpy.random as npr
 from jax.api import jit, grad
 from jax.config import config
-from jax.scipy.misc import logsumexp
+from jax.scipy.special import logsumexp
 import jax.numpy as np
 from examples import datasets
@@ -31,7 +31,7 @@ from six.moves import reduce
 from jax import lax
 from jax import random
-from jax.scipy.misc import logsumexp
+from jax.scipy.special import logsumexp
 import jax.numpy as np
@@ -20,18 +20,9 @@ import numpy as onp
 import scipy.misc as osp_misc

-from .. import lax
-from ..numpy.lax_numpy import _wraps, _reduction_dims, _constant_like
-
-
-@_wraps(osp_misc.logsumexp)
-def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
-  if b is not None or return_sign:
-    raise NotImplementedError("Only implemented for b=None, return_sign=False")
-  dims = _reduction_dims(a, axis)
-  shape = lax.subvals(onp.shape(a), zip(dims, (1,) * len(dims)))
-  dimadd = lambda x: lax.reshape(x, shape)
-  amax = lax.reduce(a, _constant_like(a, -onp.inf), lax.max, dims)
-  amax_singletons = dimadd(amax)
-  out = lax.add(lax.log(lax.reduce(lax.exp(lax.sub(a, amax_singletons)),
-                                   _constant_like(a, 0), lax.add, dims)), amax)
-  return dimadd(out) if keepdims else out
+from ..scipy import special
+
+
+if hasattr(osp_misc, 'logsumexp'):
+  logsumexp = special.logsumexp
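One way to read the shim above: jax.scipy.misc.logsumexp survives only as an alias for the new function, and only while the installed SciPy itself still exposes scipy.misc.logsumexp. A hedged sketch (not from the diff) of how downstream code might stay importable across both old and new JAX versions:

# Compatibility sketch; written for illustration, not part of the commit.
try:
    from jax.scipy.special import logsumexp   # new, preferred location
except ImportError:
    from jax.scipy.misc import logsumexp      # pre-commit location, now deprecated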
@@ -16,10 +16,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import numpy as onp
 import scipy.special as osp_special

 from .. import lax
-from ..numpy.lax_numpy import _wraps, asarray
+from ..numpy.lax_numpy import _wraps, asarray, _reduction_dims, _constant_like


 # need to create new functions because _wraps sets the __name__ attribute
@@ -41,3 +42,17 @@ def expit(x):
   x = asarray(x)
   one = lax._const(x, 1)
   return lax.div(one, lax.add(one, lax.exp(lax.neg(x))))
+
+
+@_wraps(osp_special.logsumexp)
+def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
+  if b is not None or return_sign:
+    raise NotImplementedError("Only implemented for b=None, return_sign=False")
+  dims = _reduction_dims(a, axis)
+  shape = lax.subvals(onp.shape(a), zip(dims, (1,) * len(dims)))
+  dimadd = lambda x: lax.reshape(x, shape)
+  amax = lax.reduce(a, _constant_like(a, -onp.inf), lax.max, dims)
+  amax_singletons = dimadd(amax)
+  out = lax.add(lax.log(lax.reduce(lax.exp(lax.sub(a, amax_singletons)),
+                                   _constant_like(a, 0), lax.add, dims)), amax)
+  return dimadd(out) if keepdims else out
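A small numerical sketch (not part of the diff) of why the added implementation subtracts the per-axis maximum (amax) before exponentiating: the naive formula overflows for large inputs, while the shifted form computes the same quantity with bounded exponents:

import numpy as onp

a = onp.array([1000.0, 1000.5, 999.0])

naive = onp.log(onp.sum(onp.exp(a)))                  # exp(1000.0) overflows -> inf
amax = onp.max(a)
stable = amax + onp.log(onp.sum(onp.exp(a - amax)))   # shifted exponents are all <= 0

print(naive, stable)   # inf vs. a finite value (about 1001.1)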
@@ -87,10 +87,10 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
   def testLogSumExp(self, rng, shape, dtype, axis, keepdims):
     # TODO(mattjj): test autodiff
     def scipy_fun(array_to_reduce):
-      return osp_misc.logsumexp(array_to_reduce, axis, keepdims=keepdims)
+      return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims)

     def lax_fun(array_to_reduce):
-      return lsp_misc.logsumexp(array_to_reduce, axis, keepdims=keepdims)
+      return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims)

     args_maker = lambda: [rng(shape, dtype)]
     self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)