diff --git a/examples/mnist_classifier_fromscratch.py b/examples/mnist_classifier_fromscratch.py
index 419daa057..f79861a58 100644
--- a/examples/mnist_classifier_fromscratch.py
+++ b/examples/mnist_classifier_fromscratch.py
@@ -27,7 +27,7 @@ import numpy.random as npr
 
 from jax.api import jit, grad
 from jax.config import config
-from jax.scipy.misc import logsumexp
+from jax.scipy.special import logsumexp
 import jax.numpy as np
 from examples import datasets
 
diff --git a/jax/experimental/stax.py b/jax/experimental/stax.py
index 373f9f1b7..311c9d3f5 100644
--- a/jax/experimental/stax.py
+++ b/jax/experimental/stax.py
@@ -31,7 +31,7 @@ from six.moves import reduce
 
 from jax import lax
 from jax import random
-from jax.scipy.misc import logsumexp
+from jax.scipy.special import logsumexp
 import jax.numpy as np
 
 
diff --git a/jax/scipy/misc.py b/jax/scipy/misc.py
index 57c237c4c..4bd091ba6 100644
--- a/jax/scipy/misc.py
+++ b/jax/scipy/misc.py
@@ -20,18 +20,9 @@ import numpy as onp
 import scipy.misc as osp_misc
 
 from .. import lax
+from ..scipy import special
 from ..numpy.lax_numpy import _wraps, _reduction_dims, _constant_like
 
 
-@_wraps(osp_misc.logsumexp)
-def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
-  if b is not None or return_sign:
-    raise NotImplementedError("Only implemented for b=None, return_sign=False")
-  dims = _reduction_dims(a, axis)
-  shape = lax.subvals(onp.shape(a), zip(dims, (1,) * len(dims)))
-  dimadd = lambda x: lax.reshape(x, shape)
-  amax = lax.reduce(a, _constant_like(a, -onp.inf), lax.max, dims)
-  amax_singletons = dimadd(amax)
-  out = lax.add(lax.log(lax.reduce(lax.exp(lax.sub(a, amax_singletons)),
-                                   _constant_like(a, 0), lax.add, dims)), amax)
-  return dimadd(out) if keepdims else out
+if hasattr(osp_misc, 'logsumexp'):
+  logsumexp = special.logsumexp
diff --git a/jax/scipy/special.py b/jax/scipy/special.py
index 9667feedd..0ff2634e5 100644
--- a/jax/scipy/special.py
+++ b/jax/scipy/special.py
@@ -16,10 +16,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numpy as onp
 import scipy.special as osp_special
 
 from .. import lax
-from ..numpy.lax_numpy import _wraps, asarray
+from ..numpy.lax_numpy import _wraps, asarray, _reduction_dims, _constant_like
 
 
 # need to create new functions because _wraps sets the __name__ attribute
@@ -41,3 +42,17 @@ def expit(x):
   x = asarray(x)
   one = lax._const(x, 1)
   return lax.div(one, lax.add(one, lax.exp(lax.neg(x))))
+
+
+@_wraps(osp_special.logsumexp)
+def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
+  if b is not None or return_sign:
+    raise NotImplementedError("Only implemented for b=None, return_sign=False")
+  dims = _reduction_dims(a, axis)
+  shape = lax.subvals(onp.shape(a), zip(dims, (1,) * len(dims)))
+  dimadd = lambda x: lax.reshape(x, shape)
+  amax = lax.reduce(a, _constant_like(a, -onp.inf), lax.max, dims)
+  amax_singletons = dimadd(amax)
+  out = lax.add(lax.log(lax.reduce(lax.exp(lax.sub(a, amax_singletons)),
+                                   _constant_like(a, 0), lax.add, dims)), amax)
+  return dimadd(out) if keepdims else out
\ No newline at end of file
diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py
index 63f97d224..c002fe057 100644
--- a/tests/lax_scipy_test.py
+++ b/tests/lax_scipy_test.py
@@ -87,10 +87,10 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
   def testLogSumExp(self, rng, shape, dtype, axis, keepdims):
     # TODO(mattjj): test autodiff
     def scipy_fun(array_to_reduce):
-      return osp_misc.logsumexp(array_to_reduce, axis, keepdims=keepdims)
+      return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims)
 
     def lax_fun(array_to_reduce):
-      return lsp_misc.logsumexp(array_to_reduce, axis, keepdims=keepdims)
+      return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims)
 
     args_maker = lambda: [rng(shape, dtype)]
     self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
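For reviewers who want to sanity-check the relocated function locally, here is a minimal usage sketch (not part of the diff). It assumes a checkout with this change applied; the input values are arbitrary and only illustrate the new import path and the supported arguments.

# Minimal sketch: compare the relocated jax.scipy.special.logsumexp against the
# SciPy reference, mirroring what the updated test in tests/lax_scipy_test.py does.
import numpy as onp
import scipy.special as osp_special

from jax.scipy.special import logsumexp  # new import path after this change

x = onp.random.RandomState(0).randn(3, 4)

# logsumexp(x) = amax + log(sum(exp(x - amax))): the max is subtracted before
# exponentiating for numerical stability (see the implementation added to special.py).
out = logsumexp(x, axis=1, keepdims=True)
ref = osp_special.logsumexp(x, axis=1, keepdims=True)
print(onp.allclose(onp.asarray(out), ref, atol=1e-6))  # expected: True

# Note: b= and return_sign= remain unsupported and raise NotImplementedError.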