mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #15140 from jakevdp:doc-random
PiperOrigin-RevId: 518650747
This commit is contained in:
commit
68503629b7
@ -552,7 +552,14 @@ def choice(key: KeyArray,
|
||||
def normal(key: KeyArray,
|
||||
shape: Union[Shape, NamedShape] = (),
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
"""Sample standard normal random values with given shape and float dtype.
|
||||
r"""Sample standard normal random values with given shape and float dtype.
|
||||
|
||||
The values are returned according to the probability density function:
|
||||
|
||||
.. math::
|
||||
f(x) = \frac{1}{\sqrt{2\pi}}e^{-x^2/2}
|
||||
|
||||
on the domain :math:`-\infty < x < \infty`
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
@ -600,7 +607,15 @@ def multivariate_normal(key: KeyArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeFloat = None,
|
||||
method: str = 'cholesky') -> Array:
|
||||
"""Sample multivariate normal random values with given mean and covariance.
|
||||
r"""Sample multivariate normal random values with given mean and covariance.
|
||||
|
||||
The values are returned according to the probability density function:
|
||||
|
||||
.. math::
|
||||
f(x;\mu, \Sigma) = (2\pi)^{-k/2} \det(\Sigma)^{-1}e^{-\frac{1}{2}(x - \mu)^T \Sigma^{-1} (x - \mu)}
|
||||
|
||||
where :math:`k` is the dimension, :math:`\mu` is the mean (given by ``mean``) and
|
||||
:math:`\Sigma` is the covariance matrix (given by ``cov``).
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
@ -673,7 +688,14 @@ def truncated_normal(key: KeyArray,
|
||||
upper: RealArray,
|
||||
shape: Optional[Union[Shape, NamedShape]] = None,
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
"""Sample truncated standard normal random values with given shape and dtype.
|
||||
r"""Sample truncated standard normal random values with given shape and dtype.
|
||||
|
||||
The values are returned according to the probability density function:
|
||||
|
||||
.. math::
|
||||
f(x) \propto e^{-x^2/2}
|
||||
|
||||
on the domain :math:`\rm{lower} < x < \rm{upper}`.
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
@ -729,7 +751,14 @@ def _truncated_normal(key, lower, upper, shape, dtype) -> Array:
|
||||
def bernoulli(key: KeyArray,
|
||||
p: RealArray = np.float32(0.5),
|
||||
shape: Optional[Union[Shape, NamedShape]] = None) -> Array:
|
||||
"""Sample Bernoulli random values with given shape and mean.
|
||||
r"""Sample Bernoulli random values with given shape and mean.
|
||||
|
||||
The values are distributed according to the probability mass function:
|
||||
|
||||
.. math::
|
||||
f(k; p) = p^k(1 - p)^{1 - k}
|
||||
|
||||
where :math:`k \in \{0, 1\}` and :math:`0 \le p \le 1`.
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
@ -769,7 +798,14 @@ def beta(key: KeyArray,
|
||||
b: RealArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
"""Sample Beta random values with given shape and float dtype.
|
||||
r"""Sample Beta random values with given shape and float dtype.
|
||||
|
||||
The values are distributed according to the probability density function:
|
||||
|
||||
.. math::
|
||||
f(x;a,b) \propto x^{a - 1}(1 - x)^{b - 1}
|
||||
|
||||
on the domain :math:`0 \le x \le 1`.
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
@ -820,7 +856,14 @@ def _beta(key, a, b, shape, dtype) -> Array:
|
||||
def cauchy(key: KeyArray,
|
||||
shape: Shape = (),
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
"""Sample Cauchy random values with given shape and float dtype.
|
||||
r"""Sample Cauchy random values with given shape and float dtype.
|
||||
|
||||
The values are distributed according to the probability density function:
|
||||
|
||||
.. math::
|
||||
f(x) \propto \frac{1}{x^2 + 1}
|
||||
|
||||
on the domain :math:`-\infty < x < \infty`
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
@ -852,7 +895,19 @@ def dirichlet(key: KeyArray,
|
||||
alpha: RealArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
"""Sample Dirichlet random values with given shape and float dtype.
|
||||
r"""Sample Dirichlet random values with given shape and float dtype.
|
||||
|
||||
The values are distributed according the the probability density function:
|
||||
|
||||
.. math::
|
||||
f(\{x_i\}; \{\alpha_i\}) = \propto \prod_{i=1}^k x_i^{\alpha_i}
|
||||
|
||||
Where :math:`k` is the dimension, and :math:`\{x_i\}` satisfies
|
||||
|
||||
.. math::
|
||||
\sum_{i=1}^k x_i = 1
|
||||
|
||||
and :math:`0 \le x_i \le 1` for all :math:`x_i`.
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
@ -910,7 +965,14 @@ def _softmax(x, axis) -> Array:
|
||||
def exponential(key: KeyArray,
|
||||
shape: Shape = (),
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
"""Sample Exponential random values with given shape and float dtype.
|
||||
r"""Sample Exponential random values with given shape and float dtype.
|
||||
|
||||
The values are distributed according the the probability density function:
|
||||
|
||||
.. math::
|
||||
f(x) = e^{-x}
|
||||
|
||||
on the domain :math:`0 \le x < \infty`.
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
@ -1074,9 +1136,16 @@ def gamma(key: KeyArray,
|
||||
a: RealArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
"""Sample Gamma random values with given shape and float dtype.
|
||||
r"""Sample Gamma random values with given shape and float dtype.
|
||||
|
||||
This implements the standard gamma density, with a unit scale/rate parameter.
|
||||
The values are distributed according the the probability density function:
|
||||
|
||||
.. math::
|
||||
f(x;a) \propto x^{a - 1} e^{-x}
|
||||
|
||||
on the domain :math:`0 \le x < \infty`, with :math:`a > 0`.
|
||||
|
||||
This is the standard gamma density, with a unit scale/rate parameter.
|
||||
Dividing the sample output by the rate is equivalent to sampling from
|
||||
*gamma(a, rate)*, and multiplying the sample output by the scale is equivalent
|
||||
to sampling from *gamma(a, scale)*.
|
||||
@ -1254,7 +1323,14 @@ def poisson(key: KeyArray,
|
||||
lam: RealArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeInt = dtypes.int_) -> Array:
|
||||
"""Sample Poisson random values with given shape and integer dtype.
|
||||
r"""Sample Poisson random values with given shape and integer dtype.
|
||||
|
||||
The values are distributed according to the probability mass function:
|
||||
|
||||
.. math::
|
||||
f(k; \lambda) = \frac{\lambda^k e^{-\lambda}}{k!}
|
||||
|
||||
Where `k` is a non-negative integer and :math:`\lambda > 0`.
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
@ -1291,6 +1367,11 @@ def gumbel(key: KeyArray,
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
"""Sample Gumbel random values with given shape and float dtype.
|
||||
|
||||
The values are distributed according to the probability density function:
|
||||
|
||||
.. math::
|
||||
f(x) = e^{-(x + e^{-x})}
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
shape: optional, a tuple of nonnegative integers representing the result
|
||||
@ -1361,7 +1442,12 @@ def categorical(key: KeyArray,
|
||||
def laplace(key: KeyArray,
|
||||
shape: Shape = (),
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
"""Sample Laplace random values with given shape and float dtype.
|
||||
r"""Sample Laplace random values with given shape and float dtype.
|
||||
|
||||
The values are distributed according to the probability density function:
|
||||
|
||||
.. math::
|
||||
f(x) = \frac{1}{2}e^{-|x|}
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
@ -1392,7 +1478,12 @@ def _laplace(key, shape, dtype) -> Array:
|
||||
def logistic(key: KeyArray,
|
||||
shape: Shape = (),
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
"""Sample logistic random values with given shape and float dtype.
|
||||
r"""Sample logistic random values with given shape and float dtype.
|
||||
|
||||
The values are distributed according to the probability density function:
|
||||
|
||||
.. math::
|
||||
f(x) = \frac{e^{-x}}{(1 + e^{-x})^2}
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
@ -1423,7 +1514,14 @@ def pareto(key: KeyArray,
|
||||
b: RealArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
"""Sample Pareto random values with given shape and float dtype.
|
||||
r"""Sample Pareto random values with given shape and float dtype.
|
||||
|
||||
The values are distributed according to the probability density function:
|
||||
|
||||
.. math::
|
||||
f(x; b) = b / x^{b + 1}
|
||||
|
||||
on the domain :math:`0 \le x < \infty` with :math:`b > 0`
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
@ -1464,12 +1562,19 @@ def t(key: KeyArray,
|
||||
df: RealArray,
|
||||
shape: Shape = (),
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
"""Sample Student's t random values with given shape and float dtype.
|
||||
r"""Sample Student's t random values with given shape and float dtype.
|
||||
|
||||
The values are distributed according to the probability density function:
|
||||
|
||||
.. math::
|
||||
f(t; \nu) \propto \left(1 + \frac{t^2}{\nu}\right)^{-(\nu + 1)/2}
|
||||
|
||||
Where :math:`\nu > 0` is the degrees of freedom, given by the parameter ``df``.
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
df: a float or array of floats broadcast-compatible with ``shape``
|
||||
representing the parameter of the distribution.
|
||||
representing the degrees of freedom parameter of the distribution.
|
||||
shape: optional, a tuple of nonnegative integers specifying the result
|
||||
shape. Must be broadcast-compatible with ``df``. The default (None)
|
||||
produces a result shape equal to ``df.shape``.
|
||||
@ -1508,7 +1613,15 @@ def chisquare(key: KeyArray,
|
||||
df: RealArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
"""Sample Chisquare random values with given shape and float dtype.
|
||||
r"""Sample Chisquare random values with given shape and float dtype.
|
||||
|
||||
The values are distributed according to the probability density function:
|
||||
|
||||
.. math::
|
||||
f(x; \nu) \propto x^{k/2 - 1}e^{-x/2}
|
||||
|
||||
on the domain :math:`0 < x < \infty`, where :math:`\nu > 0` represents the
|
||||
degrees of freedom, given by the parameter ``df``.
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
@ -1552,7 +1665,17 @@ def f(key: KeyArray,
|
||||
dfden: RealArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
"""Sample F-distribution random values with given shape and float dtype.
|
||||
r"""Sample F-distribution random values with given shape and float dtype.
|
||||
|
||||
The values are distributed according to the probability density function:
|
||||
|
||||
.. math::
|
||||
f(x; \nu) \propto x^{\nu_1/2 - 1}\left(1 + \frac{\nu_1}{\nu_2}x\right)^{
|
||||
-(\nu_1 + \nu_2) / 2}
|
||||
|
||||
on the domain :math:`0 < x < \infty`. Here :math:`\nu_1` is the degrees of
|
||||
freedom of the numerator (``dfnum``), and :math:`\nu_2` is the degrees of
|
||||
freedom of the denominator (``dfden``).
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
@ -1603,7 +1726,14 @@ def _f(key, dfnum, dfden, shape, dtype) -> Array:
|
||||
def rademacher(key: KeyArray,
|
||||
shape: Shape,
|
||||
dtype: DTypeLikeInt = dtypes.int_) -> Array:
|
||||
"""Sample from a Rademacher distribution.
|
||||
r"""Sample from a Rademacher distribution.
|
||||
|
||||
The values are distributed according to the probability mass function:
|
||||
|
||||
.. math::
|
||||
f(k) = \frac{1}{2}(\delta(k - 1) + \delta(k + 1))
|
||||
|
||||
on the domain :math:`k \in \{-1, 1}`, where `\delta(x)` is the dirac delta function.
|
||||
|
||||
Args:
|
||||
key: a PRNG key.
|
||||
@ -1630,9 +1760,14 @@ def _rademacher(key, shape, dtype) -> Array:
|
||||
def maxwell(key: KeyArray,
|
||||
shape: Shape = (),
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
"""Sample from a one sided Maxwell distribution.
|
||||
r"""Sample from a one sided Maxwell distribution.
|
||||
|
||||
The scipy counterpart is `scipy.stats.maxwell`.
|
||||
The values are distributed according to the probability density function:
|
||||
|
||||
.. math::
|
||||
f(x) \propto x^2 e^{-x^2 / 2}
|
||||
|
||||
on the domain :math:`0 \le x < \infty`.
|
||||
|
||||
Args:
|
||||
key: a PRNG key.
|
||||
@ -1666,10 +1801,15 @@ def double_sided_maxwell(key: KeyArray,
|
||||
scale: RealArray,
|
||||
shape: Shape = (),
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
"""Sample from a double sided Maxwell distribution.
|
||||
r"""Sample from a double sided Maxwell distribution.
|
||||
|
||||
Samples using:
|
||||
loc + scale* sgn(U-0.5)* one_sided_maxwell U~Unif;
|
||||
The values are distributed according to the probability density function:
|
||||
|
||||
.. math::
|
||||
f(x;\mu,\sigma) \propto z^2 e^{-z^2 / 2}
|
||||
|
||||
where :math:`z = (x - \mu) / \sigma`, with the center :math:`\mu` specified by
|
||||
``loc`` and the scale :math:`\sigma` specified by ``scale``.
|
||||
|
||||
Args:
|
||||
key: a PRNG key.
|
||||
@ -1712,9 +1852,15 @@ def weibull_min(key: KeyArray,
|
||||
concentration: RealArray,
|
||||
shape: Shape = (),
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
"""Sample from a Weibull distribution.
|
||||
r"""Sample from a Weibull distribution.
|
||||
|
||||
The scipy counterpart is `scipy.stats.weibull_min`.
|
||||
The values are distributed according to the probability density function:
|
||||
|
||||
.. math::
|
||||
f(x;\sigma,c) \propto x^{c - 1} \exp(-(x / \sigma)^c)
|
||||
|
||||
on the domain :math:`0 < x < \infty`, where :math:`c > 0` is the concentration
|
||||
parameter, and :math:`\sigma > 0` is the scale parameter.
|
||||
|
||||
Args:
|
||||
key: a PRNG key.
|
||||
@ -1788,7 +1934,15 @@ def generalized_normal(
|
||||
shape: Shape = (),
|
||||
dtype: DTypeLikeFloat = dtypes.float_
|
||||
) -> Array:
|
||||
"""Sample from the generalized normal distribution.
|
||||
r"""Sample from the generalized normal distribution.
|
||||
|
||||
The values are returned according to the probability density function:
|
||||
|
||||
.. math::
|
||||
f(x;p) \propto e^{-|x|^p}
|
||||
|
||||
on the domain :math:`-\infty < x < \infty`, where :math:`p > 0` is the
|
||||
shape parameter.
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
@ -1842,7 +1996,15 @@ def rayleigh(key: KeyArray,
|
||||
scale: RealArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
"""Sample Rayleigh random values with given shape and float dtype.
|
||||
r"""Sample Rayleigh random values with given shape and float dtype.
|
||||
|
||||
The values are returned according to the probability density function:
|
||||
|
||||
.. math::
|
||||
f(x;\sigma) \propto xe^{-x^2/(2\sigma^2)}
|
||||
|
||||
on the domain :math:`-\infty < x < \infty`, and where `\sigma > 0` is the scale
|
||||
parameter of the distribution.
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
@ -1885,7 +2047,16 @@ def wald(key: KeyArray,
|
||||
mean: RealArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
"""Sample Wald random values with given shape and float dtype.
|
||||
r"""Sample Wald random values with given shape and float dtype.
|
||||
|
||||
The values are returned according to the probability density function:
|
||||
|
||||
.. math::
|
||||
f(x;\mu) = \frac{1}{\sqrt{2\pi x^3}} \exp\left(-\frac{(x - \mu)^2}{2\mu^2 x}\right)
|
||||
|
||||
on the domain :math:`-\infty < x < \infty`, and where :math:`\mu > 0` is the location
|
||||
parameter of the distribution.
|
||||
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
|
Loading…
x
Reference in New Issue
Block a user