mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #19302 from carlosgmartin:scipy-stats-sem
PiperOrigin-RevId: 598884144
This commit is contained in:
commit
94b2da6a3b
@ -44,6 +44,7 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
`from jax.experimental.export import export` you should use now
|
||||
`from jax.experimental import export`. The old way of importing will
|
||||
continue to work for a deprecation period of 3 months.
|
||||
* Added {func}`jax.scipy.stats.sem`.
|
||||
|
||||
* Deprecations & Removals
|
||||
* A number of previously deprecated functions have been removed, following a
|
||||
|
@ -180,6 +180,7 @@ jax.scipy.stats
|
||||
|
||||
mode
|
||||
rankdata
|
||||
sem
|
||||
|
||||
jax.scipy.stats.bernoulli
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
@ -23,7 +23,7 @@ import jax.numpy as jnp
|
||||
from jax import jit
|
||||
from jax._src import dtypes
|
||||
from jax._src.api import vmap
|
||||
from jax._src.numpy.util import check_arraylike, _wraps
|
||||
from jax._src.numpy.util import check_arraylike, _wraps, promote_args_inexact
|
||||
from jax._src.typing import ArrayLike, Array
|
||||
from jax._src.util import canonicalize_axis
|
||||
|
||||
@ -147,3 +147,20 @@ def rankdata(
|
||||
if method == "average":
|
||||
return .5 * (count[dense] + count[dense - 1] + 1).astype(dtypes.canonicalize_dtype(jnp.float_))
|
||||
raise ValueError(f"unknown method '{method}'")
|
||||
|
||||
@_wraps(scipy.stats.sem, lax_description="""\
|
||||
Currently the only supported nan_policies are 'propagate' and 'omit'
|
||||
""")
|
||||
@partial(jit, static_argnames=['axis', 'nan_policy', 'keepdims'])
|
||||
def sem(a: ArrayLike, axis: int | None = 0, ddof: int = 1, nan_policy: str = "propagate", *, keepdims: bool = False) -> Array:
|
||||
b, = promote_args_inexact("sem", a)
|
||||
if axis is None:
|
||||
b = b.ravel()
|
||||
axis = 0
|
||||
if nan_policy == "propagate":
|
||||
return b.std(axis, ddof=ddof) / jnp.sqrt(b.shape[axis]).astype(b.dtype)
|
||||
elif nan_policy == "omit":
|
||||
count = (~jnp.isnan(b)).sum(axis)
|
||||
return jnp.nanstd(b, axis, ddof=ddof) / jnp.sqrt(count).astype(b.dtype)
|
||||
else:
|
||||
raise ValueError(f"{nan_policy} is not supported")
|
||||
|
@ -38,6 +38,6 @@ from jax.scipy.stats import betabinom as betabinom
|
||||
from jax.scipy.stats import gennorm as gennorm
|
||||
from jax.scipy.stats import truncnorm as truncnorm
|
||||
from jax._src.scipy.stats.kde import gaussian_kde as gaussian_kde
|
||||
from jax._src.scipy.stats._core import mode as mode, rankdata as rankdata
|
||||
from jax._src.scipy.stats._core import mode as mode, rankdata as rankdata, sem as sem
|
||||
from jax.scipy.stats import vonmises as vonmises
|
||||
from jax.scipy.stats import wrapcauchy as wrapcauchy
|
||||
|
@ -1585,5 +1585,27 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
tol=tol)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=tol)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, axis=axis, ddof=ddof, nan_policy=nan_policy)
|
||||
for shape in [(5,), (5, 6), (5, 6, 7)]
|
||||
for axis in [None, *range(len(shape))]
|
||||
for ddof in [0, 1, 2, 3]
|
||||
for nan_policy in ["propagate", "omit"]
|
||||
],
|
||||
dtype=jtu.dtypes.integer + jtu.dtypes.floating,
|
||||
)
|
||||
def testSEM(self, shape, dtype, axis, ddof, nan_policy):
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
||||
scipy_fun = partial(osp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy)
|
||||
lax_fun = partial(lsp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy)
|
||||
tol_spec = {np.float32: 2e-4, np.float64: 5e-6}
|
||||
tol = jtu.tolerance(dtype, tol_spec)
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
atol=tol)
|
||||
self._CompileAndCheck(lax_fun, args_maker, atol=tol)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user