Merge pull request #19302 from carlosgmartin:scipy-stats-sem

PiperOrigin-RevId: 598884144
This commit is contained in:
jax authors 2024-01-16 10:34:45 -08:00
commit 94b2da6a3b
5 changed files with 43 additions and 2 deletions

View File

@ -44,6 +44,7 @@ Remember to align the itemized text with the first line of an item within a list
`from jax.experimental.export import export` you should use now
`from jax.experimental import export`. The old way of importing will
continue to work for a deprecation period of 3 months.
* Added {func}`jax.scipy.stats.sem`.
* Deprecations & Removals
* A number of previously deprecated functions have been removed, following a

View File

@ -180,6 +180,7 @@ jax.scipy.stats
mode
rankdata
sem
jax.scipy.stats.bernoulli
~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -23,7 +23,7 @@ import jax.numpy as jnp
from jax import jit
from jax._src import dtypes
from jax._src.api import vmap
from jax._src.numpy.util import check_arraylike, _wraps
from jax._src.numpy.util import check_arraylike, _wraps, promote_args_inexact
from jax._src.typing import ArrayLike, Array
from jax._src.util import canonicalize_axis
@ -147,3 +147,20 @@ def rankdata(
if method == "average":
return .5 * (count[dense] + count[dense - 1] + 1).astype(dtypes.canonicalize_dtype(jnp.float_))
raise ValueError(f"unknown method '{method}'")
@_wraps(scipy.stats.sem, lax_description="""\
Currently the only supported nan_policies are 'propagate' and 'omit'
""")
@partial(jit, static_argnames=['axis', 'nan_policy', 'keepdims'])
def sem(a: ArrayLike, axis: int | None = 0, ddof: int = 1, nan_policy: str = "propagate", *, keepdims: bool = False) -> Array:
b, = promote_args_inexact("sem", a)
if axis is None:
b = b.ravel()
axis = 0
if nan_policy == "propagate":
return b.std(axis, ddof=ddof) / jnp.sqrt(b.shape[axis]).astype(b.dtype)
elif nan_policy == "omit":
count = (~jnp.isnan(b)).sum(axis)
return jnp.nanstd(b, axis, ddof=ddof) / jnp.sqrt(count).astype(b.dtype)
else:
raise ValueError(f"{nan_policy} is not supported")

View File

@ -38,6 +38,6 @@ from jax.scipy.stats import betabinom as betabinom
from jax.scipy.stats import gennorm as gennorm
from jax.scipy.stats import truncnorm as truncnorm
from jax._src.scipy.stats.kde import gaussian_kde as gaussian_kde
from jax._src.scipy.stats._core import mode as mode, rankdata as rankdata
from jax._src.scipy.stats._core import mode as mode, rankdata as rankdata, sem as sem
from jax.scipy.stats import vonmises as vonmises
from jax.scipy.stats import wrapcauchy as wrapcauchy

View File

@ -1585,5 +1585,27 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tol=tol)
self._CompileAndCheck(lax_fun, args_maker, rtol=tol)
@jtu.sample_product(
[dict(shape=shape, axis=axis, ddof=ddof, nan_policy=nan_policy)
for shape in [(5,), (5, 6), (5, 6, 7)]
for axis in [None, *range(len(shape))]
for ddof in [0, 1, 2, 3]
for nan_policy in ["propagate", "omit"]
],
dtype=jtu.dtypes.integer + jtu.dtypes.floating,
)
def testSEM(self, shape, dtype, axis, ddof, nan_policy):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
scipy_fun = partial(osp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy)
lax_fun = partial(lsp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy)
tol_spec = {np.float32: 2e-4, np.float64: 5e-6}
tol = jtu.tolerance(dtype, tol_spec)
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
atol=tol)
self._CompileAndCheck(lax_fun, args_maker, atol=tol)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())