@_wraps(scipy.stats.mode, lax_description="""\
Currently the only supported nan_policy is 'propagate'
""")
@partial(jit, static_argnames=['axis', 'nan_policy', 'keepdims'])
def mode(a: ArrayLike, axis: Optional[int] = 0, nan_policy: str = "propagate", keepdims: bool = False) -> ModeResult:
  # Returns the most frequent value of `a`, and how often it occurs, along
  # `axis` (or over the flattened input when `axis` is None).
  #
  # Args:
  #   a: array-like input; promoted to at least 1-D.
  #   axis: dimension to reduce over. May be negative. None flattens `a` first.
  #   nan_policy: only "propagate" is implemented. "omit" and "raise" are
  #     recognised but raise NotImplementedError.
  #   keepdims: if True the reduced dimension is kept with length 1 (every
  #     dimension becomes length 1 when `axis` is None).
  #
  # Returns:
  #   ModeResult(mode, count), both reshaped to the reduced output shape.
  #
  # Raises:
  #   ValueError: `nan_policy` is not one of the three recognised names.
  #   NotImplementedError: `nan_policy` is "omit" or "raise".
  _check_arraylike("mode", a)
  x = jnp.atleast_1d(a)

  if nan_policy not in ["propagate", "omit", "raise"]:
    raise ValueError(
      f"Illegal nan_policy value {nan_policy!r}; expected one of "
      "{'propagate', 'omit', 'raise'}"
    )
  if nan_policy == "omit":
    # TODO: return answer without nans included.
    raise NotImplementedError(
      f"Logic for `nan_policy` of {nan_policy} is not implemented"
    )
  if nan_policy == "raise":
    raise NotImplementedError(
      "In order to best JIT compile `mode`, we cannot know whether `x` contains nans. "
      "Please check if nans exist in `x` outside of the `mode` function."
    )

  input_shape = x.shape
  if axis is None:
    # Reduce over everything: flatten and treat the result as a single column.
    output_shape = tuple(1 for _ in input_shape) if keepdims else ()
    x = x.ravel()
    axis = 0
  else:
    # Normalise `axis` *before* deriving the output shape. The comparisons
    # below are against non-negative dimension indices, so a negative axis
    # would otherwise never match and the final reshape would be wrong.
    axis = canonicalize_axis(axis, x.ndim)
    if keepdims:
      output_shape = tuple(1 if i == axis else s for i, s in enumerate(input_shape))
    else:
      output_shape = tuple(s for i, s in enumerate(input_shape) if i != axis)

  def _mode_helper(x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Helper function to return mode and count of a given array."""
    if x.size == 0:
      # Nothing to count: report NaN for both the mode and its count.
      nan_dtype = dtypes.canonicalize_dtype(jnp.float_)
      return jnp.array(jnp.nan, dtype=nan_dtype), jnp.array(jnp.nan, dtype=nan_dtype)
    # `size=x.size` fixes the length of the unique/count outputs up front
    # (presumably so the shapes stay static under jit/vmap).
    vals, counts = jnp.unique(x, return_counts=True, size=x.size)
    best = jnp.argmax(counts)
    return vals[best], counts[best]

  # Bring the reduced axis to the front and collapse the rest into columns,
  # so each column is one independent 1-D mode problem.
  x = jnp.moveaxis(x, axis, 0)
  x = x.reshape(x.shape[0], prod(x.shape[1:]))
  vals, counts = vmap(_mode_helper, in_axes=1)(x)
  return ModeResult(vals.reshape(output_shape), counts.reshape(output_shape))
# Method of LaxBackedScipyStatsTests: indent one level when splicing into the
# class body.
@jtu.sample_product(
  [dict(shape=shape, axis=axis)
    for shape, axis in (
      ((0,), None),
      ((0,), 0),
      ((7,), None),
      ((7,), 0),
      ((47, 8), None),
      ((47, 8), 0),
      ((47, 8), 1),
      ((0, 2, 3), None),
      ((0, 2, 3), 0),
      ((0, 2, 3), 1),
      ((0, 2, 3), 2),
      ((10, 5, 21), None),
      ((10, 5, 21), 0),
      ((10, 5, 21), 1),
      ((10, 5, 21), 2),
    )
  ],
  dtype=jtu.dtypes.integer + jtu.dtypes.floating,
  contains_nans=[True, False],
  keepdims=[True, False]
)
def testMode(self, shape, dtype, axis, contains_nans, keepdims):
  """Checks `lsp_stats.mode` against `scipy.stats.mode` and under compilation.

  Cases are skipped when the installed scipy is older than 1.9.0 and
  `keepdims` is False, or when the installed numpy is older than 1.21.0 and
  the input is drawn with NaNs.
  """
  if scipy_version < (1, 9, 0) and not keepdims:
    self.skipTest("scipy < 1.9.0 only support keepdims == True")
  if numpy_version < (1, 21, 0) and contains_nans:
    self.skipTest("numpy < 1.21.0 only support contains_nans == False")

  if contains_nans:
    rng = jtu.rand_some_nan(self.rng())
  else:
    rng = jtu.rand_default(self.rng())
  args_maker = lambda: [rng(shape, dtype)]

  def scipy_mode_wrapper(a, axis=0, nan_policy='propagate', keepdims=None):
    """Wrapper to manage the shape discrepancies between scipy and jax"""
    if scipy_version < (1, 9, 0) and a.size == 0 and keepdims:
      # For empty input on old scipy the reference is built by hand rather
      # than by calling scipy: NaN mode and NaN count, in the keepdims shape.
      if axis is None:
        output_shape = tuple(1 for _ in a.shape)
      else:
        output_shape = tuple(1 if i == axis else s for i, s in enumerate(a.shape))
      nan_dtype = dtypes.canonicalize_dtype(jax.numpy.float_)
      return (np.full(output_shape, np.nan, dtype=nan_dtype),
              np.full(output_shape, np.nan, dtype=nan_dtype))

    if scipy_version < (1, 9, 0):
      # Called without `keepdims` on scipy older than 1.9.0.
      result = osp_stats.mode(a, axis=axis, nan_policy=nan_policy)
    else:
      result = osp_stats.mode(a, axis=axis, nan_policy=nan_policy, keepdims=keepdims)

    if a.size != 0 and axis is None and keepdims:
      # Reshape scipy's result to one length-1 entry per input dimension,
      # matching the keepdims shape computed for the empty case above.
      output_shape = tuple(1 for _ in a.shape)
      return (result.mode.reshape(output_shape), result.count.reshape(output_shape))
    return result

  scipy_fun = partial(scipy_mode_wrapper, axis=axis, keepdims=keepdims)
  scipy_fun = jtu.ignore_warning(category=RuntimeWarning,
                                 message="Mean of empty slice.*")(scipy_fun)
  scipy_fun = jtu.ignore_warning(category=RuntimeWarning,
                                 message="invalid value encountered.*")(scipy_fun)
  lax_fun = partial(lsp_stats.mode, axis=axis, keepdims=keepdims)
  tol_spec = {np.float32: 2e-4, np.float64: 5e-6}
  tol = jtu.tolerance(dtype, tol_spec)
  self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                          tol=tol)
  self._CompileAndCheck(lax_fun, args_maker, rtol=tol)