mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 20:06:05 +00:00
Merge pull request #22056 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 646118091
This commit is contained in:
commit
e0b2144000
@ -53,7 +53,7 @@ def mode(a: ArrayLike, axis: int | None = 0, nan_policy: str = "propagate", keep
|
||||
>>> mode, count
|
||||
(Array(4, dtype=int32), Array(3, dtype=int32))
|
||||
|
||||
For multi dimensional arrays, ``jax.scipy.stats.mode`` compuptes the ``mode``
|
||||
For multi dimensional arrays, ``jax.scipy.stats.mode`` computes the ``mode``
|
||||
and the corresponding ``count`` along ``axis=0``:
|
||||
|
||||
>>> x1 = jnp.array([[1, 2, 1, 3, 2, 1],
|
||||
@ -235,6 +235,62 @@ def sem(a: ArrayLike, axis: int | None = 0, ddof: int = 1, nan_policy: str = "pr
|
||||
|
||||
Returns:
|
||||
array
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3])
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... jax.scipy.stats.sem(x)
|
||||
Array(0.41, dtype=float32)
|
||||
|
||||
For multi dimensional arrays, ``sem`` computes standard error of mean along
|
||||
``axis=0``:
|
||||
|
||||
>>> x1 = jnp.array([[1, 2, 1, 3, 2, 1],
|
||||
... [3, 1, 3, 2, 1, 3],
|
||||
... [1, 2, 2, 3, 1, 2]])
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... jax.scipy.stats.sem(x1)
|
||||
Array([0.67, 0.33, 0.58, 0.33, 0.33, 0.58], dtype=float32)
|
||||
|
||||
If ``axis=1``, standard error of mean will be computed along ``axis 1``.
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... jax.scipy.stats.sem(x1, axis=1)
|
||||
Array([0.33, 0.4 , 0.31], dtype=float32)
|
||||
|
||||
If ``axis=None``, standard error of mean will be computed along all the axes.
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... jax.scipy.stats.sem(x1, axis=None)
|
||||
Array(0.2, dtype=float32)
|
||||
|
||||
By default, ``sem`` reduces the dimension of the result. To keep the
|
||||
dimensions same as that of the input array, the argument ``keepdims`` must
|
||||
be set to ``True``.
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... jax.scipy.stats.sem(x1, axis=1, keepdims=True)
|
||||
Array([[0.33],
|
||||
[0.4 ],
|
||||
[0.31]], dtype=float32)
|
||||
|
||||
Since, by default, ``nan_policy='propagate'``, ``sem`` propagates the ``nan``
|
||||
values in the result.
|
||||
|
||||
>>> nan = jnp.nan
|
||||
>>> x2 = jnp.array([[1, 2, 3, nan, 4, 2],
|
||||
... [4, 5, 4, 3, nan, 1],
|
||||
... [7, nan, 8, 7, 9, nan]])
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... jax.scipy.stats.sem(x2)
|
||||
Array([1.73, nan, 1.53, nan, nan, nan], dtype=float32)
|
||||
|
||||
If ``nan_policy='omit```, ``sem`` omits the ``nan`` values and computes the error
|
||||
for the remainging values along the specified axis.
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... jax.scipy.stats.sem(x2, nan_policy='omit')
|
||||
Array([1.73, 1.5 , 1.53, 2. , 2.5 , 0.5 ], dtype=float32)
|
||||
"""
|
||||
b, = promote_args_inexact("sem", a)
|
||||
if nan_policy == "propagate":
|
||||
|
Loading…
x
Reference in New Issue
Block a user