Add JAX equivalent of scipy.stats.mode

This commit is contained in:
Yann Lamidon 2022-10-06 10:19:44 +01:00
parent 66af016df3
commit ccbc3059b0
3 changed files with 157 additions and 1 deletions

View File

@ -0,0 +1,83 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial
from typing import Optional, Tuple
import jax.numpy as jnp
import scipy
from jax import jit
from jax._src import dtypes
from jax._src.api import vmap
from jax._src.numpy.lax_numpy import _check_arraylike
from jax._src.numpy.util import _wraps
from jax._src.typing import ArrayLike
from jax._src.util import canonicalize_axis, prod
ModeResult = namedtuple('ModeResult', ('mode', 'count'))
@_wraps(scipy.stats.mode, lax_description="""\
Currently the only supported nan_policy is 'propagate'
""")
@partial(jit, static_argnames=['axis', 'nan_policy', 'keepdims'])
def mode(a: ArrayLike, axis: Optional[int] = 0, nan_policy: str = "propagate", keepdims: bool = False) -> ModeResult:
_check_arraylike("mode", a)
x = jnp.atleast_1d(a)
if nan_policy not in ["propagate", "omit", "raise"]:
raise ValueError(
f"Illegal nan_policy value {nan_policy!r}; expected one of "
"{'propoagate', 'omit', 'raise'}"
)
if nan_policy == "omit":
# TODO: return answer without nans included.
raise NotImplementedError(
f"Logic for `nan_policy` of {nan_policy} is not implemented"
)
if nan_policy == "raise":
raise NotImplementedError(
"In order to best JIT compile `mode`, we cannot know whether `x` contains nans. "
"Please check if nans exist in `x` outside of the `mode` function."
)
input_shape = x.shape
if keepdims:
if axis is None:
output_shape = tuple(1 for i in input_shape)
else:
output_shape = tuple(1 if i == axis else s for i, s in enumerate(input_shape))
else:
if axis is None:
output_shape = ()
else:
output_shape = tuple(s for i, s in enumerate(input_shape) if i != axis)
if axis is None:
axis = 0
x = x.ravel()
def _mode_helper(x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Helper function to return mode and count of a given array."""
if x.size == 0:
return jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_)), jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_))
else:
vals, counts = jnp.unique(x, return_counts=True, size=x.size)
return vals[jnp.argmax(counts)], counts.max()
axis = canonicalize_axis(axis, x.ndim)
x = jnp.moveaxis(x, axis, 0)
x = x.reshape(x.shape[0], prod(x.shape[1:]))
vals, counts = vmap(_mode_helper, in_axes=1)(x)
return ModeResult(vals.reshape(output_shape), counts.reshape(output_shape))

View File

@ -33,3 +33,4 @@ from jax.scipy.stats import chi2 as chi2
from jax.scipy.stats import betabinom as betabinom
from jax.scipy.stats import gennorm as gennorm
from jax._src.scipy.stats.kde import gaussian_kde as gaussian_kde
from jax._src.scipy.stats._core import mode as mode

View File

@ -20,15 +20,19 @@ from absl.testing import absltest
import numpy as np
import scipy.stats as osp_stats
import scipy.version
import jax
from jax._src import test_util as jtu, tree_util
from jax._src import dtypes, test_util as jtu, tree_util
from jax.scipy import stats as lsp_stats
from jax.scipy.special import expit
from jax.config import config
config.parse_flags_with_absl()
scipy_version = tuple(map(int, scipy.version.version.split('.')[:3]))
numpy_version = tuple(map(int, np.version.version.split('.')[:3]))
all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]
one_and_two_dim_shapes = [(4,), (3, 4), (3, 1), (1, 4)]
@ -880,5 +884,73 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tree_util.tree_map(lambda a, b: self.assertAllClose(a, b), kde, kde2)
self.assertAllClose(evaluate_kde(kde, x), kde.evaluate(x))
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape, axis in (
((0,), None),
((0,), 0),
((7,), None),
((7,), 0),
((47, 8), None),
((47, 8), 0),
((47, 8), 1),
((0, 2, 3), None),
((0, 2, 3), 0),
((0, 2, 3), 1),
((0, 2, 3), 2),
((10, 5, 21), None),
((10, 5, 21), 0),
((10, 5, 21), 1),
((10, 5, 21), 2),
)
],
dtype=jtu.dtypes.integer + jtu.dtypes.floating,
contains_nans=[True, False],
keepdims=[True, False]
)
def testMode(self, shape, dtype, axis, contains_nans, keepdims):
if scipy_version < (1, 9, 0) and keepdims != True:
self.skipTest("scipy < 1.9.0 only support keepdims == True")
if numpy_version < (1, 21, 0) and contains_nans:
self.skipTest("numpy < 1.21.0 only support contains_nans == False")
if contains_nans:
rng = jtu.rand_some_nan(self.rng())
else:
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
def scipy_mode_wrapper(a, axis=0, nan_policy='propagate', keepdims=None):
"""Wrapper to manage the shape discrepancies between scipy and jax"""
if scipy_version < (1, 9, 0) and a.size == 0 and keepdims == True:
if axis == None:
output_shape = tuple(1 for _ in a.shape)
else:
output_shape = tuple(1 if i == axis else s for i, s in enumerate(a.shape))
return (np.full(output_shape, np.nan, dtype=dtypes.canonicalize_dtype(jax.numpy.float_)),
np.full(output_shape, np.nan, dtype=dtypes.canonicalize_dtype(jax.numpy.float_)))
if scipy_version < (1, 9, 0):
result = osp_stats.mode(a, axis=axis, nan_policy=nan_policy)
else:
result = osp_stats.mode(a, axis=axis, nan_policy=nan_policy, keepdims=keepdims)
if a.size != 0 and axis == None and keepdims == True:
output_shape = tuple(1 for _ in a.shape)
return (result.mode.reshape(output_shape), result.count.reshape(output_shape))
return result
scipy_fun = partial(scipy_mode_wrapper, axis=axis, keepdims=keepdims)
scipy_fun = jtu.ignore_warning(category=RuntimeWarning,
message="Mean of empty slice.*")(scipy_fun)
scipy_fun = jtu.ignore_warning(category=RuntimeWarning,
message="invalid value encountered.*")(scipy_fun)
lax_fun = partial(lsp_stats.mode, axis=axis, keepdims=keepdims)
tol_spec = {np.float32: 2e-4, np.float64: 5e-6}
tol = jtu.tolerance(dtype, tol_spec)
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=tol)
self._CompileAndCheck(lax_fun, args_maker, rtol=tol)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())