diff --git a/CHANGELOG.md b/CHANGELOG.md index a850b8c43..71303d206 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,6 +44,7 @@ Remember to align the itemized text with the first line of an item within a list `from jax.experimental.export import export` you should use now `from jax.experimental import export`. The old way of importing will continue to work for a deprecation period of 3 months. + * Added {func}`jax.scipy.stats.sem`. * Deprecations & Removals * A number of previously deprecated functions have been removed, following a diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index aa723ec43..6d10f5071 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -180,6 +180,7 @@ jax.scipy.stats mode rankdata + sem jax.scipy.stats.bernoulli ~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/jax/_src/scipy/stats/_core.py b/jax/_src/scipy/stats/_core.py index 541f409dc..a3b285094 100644 --- a/jax/_src/scipy/stats/_core.py +++ b/jax/_src/scipy/stats/_core.py @@ -23,7 +23,7 @@ import jax.numpy as jnp from jax import jit from jax._src import dtypes from jax._src.api import vmap -from jax._src.numpy.util import check_arraylike, _wraps +from jax._src.numpy.util import check_arraylike, _wraps, promote_args_inexact from jax._src.typing import ArrayLike, Array from jax._src.util import canonicalize_axis @@ -147,3 +147,20 @@ def rankdata( if method == "average": return .5 * (count[dense] + count[dense - 1] + 1).astype(dtypes.canonicalize_dtype(jnp.float_)) raise ValueError(f"unknown method '{method}'") + +@_wraps(scipy.stats.sem, lax_description="""\ +Currently the only supported nan_policies are 'propagate' and 'omit' +""") +@partial(jit, static_argnames=['axis', 'nan_policy', 'keepdims']) +def sem(a: ArrayLike, axis: int | None = 0, ddof: int = 1, nan_policy: str = "propagate", *, keepdims: bool = False) -> Array: + b, = promote_args_inexact("sem", a) + if axis is None: + b = b.ravel() + axis = 0 + if nan_policy == "propagate": + return b.std(axis, ddof=ddof) / jnp.sqrt(b.shape[axis]).astype(b.dtype) + elif nan_policy == "omit": + count = (~jnp.isnan(b)).sum(axis) + return jnp.nanstd(b, axis, ddof=ddof) / jnp.sqrt(count).astype(b.dtype) + else: + raise ValueError(f"{nan_policy} is not supported") diff --git a/jax/scipy/stats/__init__.py b/jax/scipy/stats/__init__.py index 9458f3b7e..7aa73f7b5 100644 --- a/jax/scipy/stats/__init__.py +++ b/jax/scipy/stats/__init__.py @@ -38,6 +38,6 @@ from jax.scipy.stats import betabinom as betabinom from jax.scipy.stats import gennorm as gennorm from jax.scipy.stats import truncnorm as truncnorm from jax._src.scipy.stats.kde import gaussian_kde as gaussian_kde -from jax._src.scipy.stats._core import mode as mode, rankdata as rankdata +from jax._src.scipy.stats._core import mode as mode, rankdata as rankdata, sem as sem from jax.scipy.stats import vonmises as vonmises from jax.scipy.stats import wrapcauchy as wrapcauchy diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index 9d7df674f..a8e51627d 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -1585,5 +1585,27 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): tol=tol) self._CompileAndCheck(lax_fun, args_maker, rtol=tol) + @jtu.sample_product( + [dict(shape=shape, axis=axis, ddof=ddof, nan_policy=nan_policy) + for shape in [(5,), (5, 6), (5, 6, 7)] + for axis in [None, *range(len(shape))] + for ddof in [0, 1, 2, 3] + for nan_policy in ["propagate", "omit"] + ], + dtype=jtu.dtypes.integer + jtu.dtypes.floating, + ) + def testSEM(self, shape, dtype, axis, ddof, nan_policy): + + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + + scipy_fun = partial(osp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy) + lax_fun = partial(lsp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy) + tol_spec = {np.float32: 2e-4, np.float64: 5e-6} + tol = jtu.tolerance(dtype, tol_spec) + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + atol=tol) + self._CompileAndCheck(lax_fun, args_maker, atol=tol) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())