mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #12683 from ylamidon:add-scipy-stats-mode
PiperOrigin-RevId: 482025648
This commit is contained in:
commit
a168c2d5b5
83
jax/_src/scipy/stats/_core.py
Normal file
83
jax/_src/scipy/stats/_core.py
Normal file
@ -0,0 +1,83 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import namedtuple
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import jax.numpy as jnp
|
||||
import scipy
|
||||
from jax import jit
|
||||
from jax._src import dtypes
|
||||
from jax._src.api import vmap
|
||||
from jax._src.numpy.lax_numpy import _check_arraylike
|
||||
from jax._src.numpy.util import _wraps
|
||||
from jax._src.typing import ArrayLike
|
||||
from jax._src.util import canonicalize_axis, prod
|
||||
|
||||
ModeResult = namedtuple('ModeResult', ('mode', 'count'))
|
||||
|
||||
@_wraps(scipy.stats.mode, lax_description="""\
|
||||
Currently the only supported nan_policy is 'propagate'
|
||||
""")
|
||||
@partial(jit, static_argnames=['axis', 'nan_policy', 'keepdims'])
|
||||
def mode(a: ArrayLike, axis: Optional[int] = 0, nan_policy: str = "propagate", keepdims: bool = False) -> ModeResult:
|
||||
_check_arraylike("mode", a)
|
||||
x = jnp.atleast_1d(a)
|
||||
|
||||
if nan_policy not in ["propagate", "omit", "raise"]:
|
||||
raise ValueError(
|
||||
f"Illegal nan_policy value {nan_policy!r}; expected one of "
|
||||
"{'propoagate', 'omit', 'raise'}"
|
||||
)
|
||||
if nan_policy == "omit":
|
||||
# TODO: return answer without nans included.
|
||||
raise NotImplementedError(
|
||||
f"Logic for `nan_policy` of {nan_policy} is not implemented"
|
||||
)
|
||||
if nan_policy == "raise":
|
||||
raise NotImplementedError(
|
||||
"In order to best JIT compile `mode`, we cannot know whether `x` contains nans. "
|
||||
"Please check if nans exist in `x` outside of the `mode` function."
|
||||
)
|
||||
|
||||
input_shape = x.shape
|
||||
if keepdims:
|
||||
if axis is None:
|
||||
output_shape = tuple(1 for i in input_shape)
|
||||
else:
|
||||
output_shape = tuple(1 if i == axis else s for i, s in enumerate(input_shape))
|
||||
else:
|
||||
if axis is None:
|
||||
output_shape = ()
|
||||
else:
|
||||
output_shape = tuple(s for i, s in enumerate(input_shape) if i != axis)
|
||||
|
||||
if axis is None:
|
||||
axis = 0
|
||||
x = x.ravel()
|
||||
|
||||
def _mode_helper(x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
||||
"""Helper function to return mode and count of a given array."""
|
||||
if x.size == 0:
|
||||
return jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_)), jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_))
|
||||
else:
|
||||
vals, counts = jnp.unique(x, return_counts=True, size=x.size)
|
||||
return vals[jnp.argmax(counts)], counts.max()
|
||||
|
||||
axis = canonicalize_axis(axis, x.ndim)
|
||||
x = jnp.moveaxis(x, axis, 0)
|
||||
x = x.reshape(x.shape[0], prod(x.shape[1:]))
|
||||
vals, counts = vmap(_mode_helper, in_axes=1)(x)
|
||||
return ModeResult(vals.reshape(output_shape), counts.reshape(output_shape))
|
@ -33,3 +33,4 @@ from jax.scipy.stats import chi2 as chi2
|
||||
from jax.scipy.stats import betabinom as betabinom
|
||||
from jax.scipy.stats import gennorm as gennorm
|
||||
from jax._src.scipy.stats.kde import gaussian_kde as gaussian_kde
|
||||
from jax._src.scipy.stats._core import mode as mode
|
||||
|
@ -20,15 +20,19 @@ from absl.testing import absltest
|
||||
|
||||
import numpy as np
|
||||
import scipy.stats as osp_stats
|
||||
import scipy.version
|
||||
|
||||
import jax
|
||||
from jax._src import test_util as jtu, tree_util
|
||||
from jax._src import dtypes, test_util as jtu, tree_util
|
||||
from jax.scipy import stats as lsp_stats
|
||||
from jax.scipy.special import expit
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
scipy_version = tuple(map(int, scipy.version.version.split('.')[:3]))
|
||||
numpy_version = tuple(map(int, np.version.version.split('.')[:3]))
|
||||
|
||||
all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]
|
||||
one_and_two_dim_shapes = [(4,), (3, 4), (3, 1), (1, 4)]
|
||||
|
||||
@ -880,5 +884,73 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
tree_util.tree_map(lambda a, b: self.assertAllClose(a, b), kde, kde2)
|
||||
self.assertAllClose(evaluate_kde(kde, x), kde.evaluate(x))
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, axis=axis)
|
||||
for shape, axis in (
|
||||
((0,), None),
|
||||
((0,), 0),
|
||||
((7,), None),
|
||||
((7,), 0),
|
||||
((47, 8), None),
|
||||
((47, 8), 0),
|
||||
((47, 8), 1),
|
||||
((0, 2, 3), None),
|
||||
((0, 2, 3), 0),
|
||||
((0, 2, 3), 1),
|
||||
((0, 2, 3), 2),
|
||||
((10, 5, 21), None),
|
||||
((10, 5, 21), 0),
|
||||
((10, 5, 21), 1),
|
||||
((10, 5, 21), 2),
|
||||
)
|
||||
],
|
||||
dtype=jtu.dtypes.integer + jtu.dtypes.floating,
|
||||
contains_nans=[True, False],
|
||||
keepdims=[True, False]
|
||||
)
|
||||
def testMode(self, shape, dtype, axis, contains_nans, keepdims):
|
||||
if scipy_version < (1, 9, 0) and keepdims != True:
|
||||
self.skipTest("scipy < 1.9.0 only support keepdims == True")
|
||||
if numpy_version < (1, 21, 0) and contains_nans:
|
||||
self.skipTest("numpy < 1.21.0 only support contains_nans == False")
|
||||
|
||||
if contains_nans:
|
||||
rng = jtu.rand_some_nan(self.rng())
|
||||
else:
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
||||
def scipy_mode_wrapper(a, axis=0, nan_policy='propagate', keepdims=None):
|
||||
"""Wrapper to manage the shape discrepancies between scipy and jax"""
|
||||
if scipy_version < (1, 9, 0) and a.size == 0 and keepdims == True:
|
||||
if axis == None:
|
||||
output_shape = tuple(1 for _ in a.shape)
|
||||
else:
|
||||
output_shape = tuple(1 if i == axis else s for i, s in enumerate(a.shape))
|
||||
return (np.full(output_shape, np.nan, dtype=dtypes.canonicalize_dtype(jax.numpy.float_)),
|
||||
np.full(output_shape, np.nan, dtype=dtypes.canonicalize_dtype(jax.numpy.float_)))
|
||||
|
||||
if scipy_version < (1, 9, 0):
|
||||
result = osp_stats.mode(a, axis=axis, nan_policy=nan_policy)
|
||||
else:
|
||||
result = osp_stats.mode(a, axis=axis, nan_policy=nan_policy, keepdims=keepdims)
|
||||
|
||||
if a.size != 0 and axis == None and keepdims == True:
|
||||
output_shape = tuple(1 for _ in a.shape)
|
||||
return (result.mode.reshape(output_shape), result.count.reshape(output_shape))
|
||||
return result
|
||||
|
||||
scipy_fun = partial(scipy_mode_wrapper, axis=axis, keepdims=keepdims)
|
||||
scipy_fun = jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="Mean of empty slice.*")(scipy_fun)
|
||||
scipy_fun = jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="invalid value encountered.*")(scipy_fun)
|
||||
lax_fun = partial(lsp_stats.mode, axis=axis, keepdims=keepdims)
|
||||
tol_spec = {np.float32: 2e-4, np.float64: 5e-6}
|
||||
tol = jtu.tolerance(dtype, tol_spec)
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=tol)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=tol)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user