mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
jax.random: warn on unsupported dtypes
This commit is contained in:
parent
faa7a68422
commit
f0309b49c9
@ -325,7 +325,7 @@ def bits(key: KeyArray,
|
||||
|
||||
def uniform(key: KeyArray,
|
||||
shape: Union[Shape, NamedShape] = (),
|
||||
dtype: DTypeLikeFloat = dtypes.float_,
|
||||
dtype: DTypeLikeFloat = float,
|
||||
minval: RealArray = 0.,
|
||||
maxval: RealArray = 1.) -> Array:
|
||||
"""Sample uniform random values in [minval, maxval) with given shape/dtype.
|
||||
@ -343,6 +343,8 @@ def uniform(key: KeyArray,
|
||||
A random array with the specified shape and dtype.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `uniform` must be a float dtype, "
|
||||
f"got {dtype}")
|
||||
@ -393,7 +395,7 @@ def randint(key: KeyArray,
|
||||
shape: Shape,
|
||||
minval: IntegerArray,
|
||||
maxval: IntegerArray,
|
||||
dtype: DTypeLikeInt = dtypes.int_) -> Array:
|
||||
dtype: DTypeLikeInt = int) -> Array:
|
||||
"""Sample uniform random values in [minval, maxval) with given shape/dtype.
|
||||
|
||||
Args:
|
||||
@ -410,6 +412,7 @@ def randint(key: KeyArray,
|
||||
A random array with the specified shape and dtype.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
shape = core.canonicalize_shape(shape)
|
||||
return _randint(key, shape, minval, maxval, dtype)
|
||||
@ -633,7 +636,7 @@ def choice(key: KeyArray,
|
||||
|
||||
def normal(key: KeyArray,
|
||||
shape: Union[Shape, NamedShape] = (),
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
dtype: DTypeLikeFloat = float) -> Array:
|
||||
r"""Sample standard normal random values with given shape and float dtype.
|
||||
|
||||
The values are returned according to the probability density function:
|
||||
@ -654,6 +657,7 @@ def normal(key: KeyArray,
|
||||
A random array with the specified shape and dtype.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.inexact):
|
||||
raise ValueError(f"dtype argument to `normal` must be a float or complex dtype, "
|
||||
f"got {dtype}")
|
||||
@ -720,6 +724,7 @@ def multivariate_normal(key: KeyArray,
|
||||
``broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
mean, cov = promote_dtypes_inexact(mean, cov)
|
||||
if method not in {'svd', 'eigh', 'cholesky'}:
|
||||
raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}")
|
||||
@ -769,7 +774,7 @@ def truncated_normal(key: KeyArray,
|
||||
lower: RealArray,
|
||||
upper: RealArray,
|
||||
shape: Optional[Union[Shape, NamedShape]] = None,
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
dtype: DTypeLikeFloat = float) -> Array:
|
||||
r"""Sample truncated standard normal random values with given shape and dtype.
|
||||
|
||||
The values are returned according to the probability density function:
|
||||
@ -798,6 +803,7 @@ def truncated_normal(key: KeyArray,
|
||||
Returns values in the open interval ``(lower, upper)``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `truncated_normal` must be a float "
|
||||
f"dtype, got {dtype}")
|
||||
@ -879,7 +885,7 @@ def beta(key: KeyArray,
|
||||
a: RealArray,
|
||||
b: RealArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
dtype: DTypeLikeFloat = float) -> Array:
|
||||
r"""Sample Beta random values with given shape and float dtype.
|
||||
|
||||
The values are distributed according to the probability density function:
|
||||
@ -906,6 +912,7 @@ def beta(key: KeyArray,
|
||||
``shape`` is not None, or else by broadcasting ``a`` and ``b``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `beta` must be a float "
|
||||
f"dtype, got {dtype}")
|
||||
@ -937,7 +944,7 @@ def _beta(key, a, b, shape, dtype) -> Array:
|
||||
|
||||
def cauchy(key: KeyArray,
|
||||
shape: Shape = (),
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
dtype: DTypeLikeFloat = float) -> Array:
|
||||
r"""Sample Cauchy random values with given shape and float dtype.
|
||||
|
||||
The values are distributed according to the probability density function:
|
||||
@ -958,6 +965,7 @@ def cauchy(key: KeyArray,
|
||||
A random array with the specified shape and dtype.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `cauchy` must be a float "
|
||||
f"dtype, got {dtype}")
|
||||
@ -976,7 +984,7 @@ def _cauchy(key, shape, dtype) -> Array:
|
||||
def dirichlet(key: KeyArray,
|
||||
alpha: RealArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
dtype: DTypeLikeFloat = float) -> Array:
|
||||
r"""Sample Dirichlet random values with given shape and float dtype.
|
||||
|
||||
The values are distributed according the the probability density function:
|
||||
@ -1009,6 +1017,7 @@ def dirichlet(key: KeyArray,
|
||||
``alpha.shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `dirichlet` must be a float "
|
||||
f"dtype, got {dtype}")
|
||||
@ -1046,7 +1055,7 @@ def _softmax(x, axis) -> Array:
|
||||
|
||||
def exponential(key: KeyArray,
|
||||
shape: Shape = (),
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
dtype: DTypeLikeFloat = float) -> Array:
|
||||
r"""Sample Exponential random values with given shape and float dtype.
|
||||
|
||||
The values are distributed according the the probability density function:
|
||||
@ -1067,6 +1076,7 @@ def exponential(key: KeyArray,
|
||||
A random array with the specified shape and dtype.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `exponential` must be a float "
|
||||
f"dtype, got {dtype}")
|
||||
@ -1213,7 +1223,7 @@ batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule
|
||||
def gamma(key: KeyArray,
|
||||
a: RealArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
dtype: DTypeLikeFloat = float) -> Array:
|
||||
r"""Sample Gamma random values with given shape and float dtype.
|
||||
|
||||
The values are distributed according the the probability density function:
|
||||
@ -1247,6 +1257,7 @@ def gamma(key: KeyArray,
|
||||
accuracy for small values of ``a``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `gamma` must be a float "
|
||||
f"dtype, got {dtype}")
|
||||
@ -1259,7 +1270,7 @@ def gamma(key: KeyArray,
|
||||
def loggamma(key: KeyArray,
|
||||
a: RealArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
dtype: DTypeLikeFloat = float) -> Array:
|
||||
"""Sample log-gamma random values with given shape and float dtype.
|
||||
|
||||
This function is implemented such that the following will hold for a
|
||||
@ -1288,6 +1299,7 @@ def loggamma(key: KeyArray,
|
||||
gamma : standard gamma sampler.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `gamma` must be a float "
|
||||
f"dtype, got {dtype}")
|
||||
@ -1400,7 +1412,7 @@ def _poisson(key, lam, shape, dtype) -> Array:
|
||||
def poisson(key: KeyArray,
|
||||
lam: RealArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeInt = dtypes.int_) -> Array:
|
||||
dtype: DTypeLikeInt = int) -> Array:
|
||||
r"""Sample Poisson random values with given shape and integer dtype.
|
||||
|
||||
The values are distributed according to the probability mass function:
|
||||
@ -1423,6 +1435,7 @@ def poisson(key: KeyArray,
|
||||
``shape is not None, or else by ``lam.shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
# TODO(frostig): generalize underlying poisson implementation and
|
||||
# remove this check
|
||||
key_impl = key.dtype.impl # type: ignore[union-attr]
|
||||
@ -1442,7 +1455,7 @@ def poisson(key: KeyArray,
|
||||
|
||||
def gumbel(key: KeyArray,
|
||||
shape: Shape = (),
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
dtype: DTypeLikeFloat = float) -> Array:
|
||||
"""Sample Gumbel random values with given shape and float dtype.
|
||||
|
||||
The values are distributed according to the probability density function:
|
||||
@ -1461,6 +1474,7 @@ def gumbel(key: KeyArray,
|
||||
A random array with the specified shape and dtype.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `gumbel` must be a float "
|
||||
f"dtype, got {dtype}")
|
||||
@ -1519,7 +1533,7 @@ def categorical(key: KeyArray,
|
||||
|
||||
def laplace(key: KeyArray,
|
||||
shape: Shape = (),
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
dtype: DTypeLikeFloat = float) -> Array:
|
||||
r"""Sample Laplace random values with given shape and float dtype.
|
||||
|
||||
The values are distributed according to the probability density function:
|
||||
@ -1538,6 +1552,7 @@ def laplace(key: KeyArray,
|
||||
A random array with the specified shape and dtype.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `laplace` must be a float "
|
||||
f"dtype, got {dtype}")
|
||||
@ -1555,7 +1570,7 @@ def _laplace(key, shape, dtype) -> Array:
|
||||
|
||||
def logistic(key: KeyArray,
|
||||
shape: Shape = (),
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
dtype: DTypeLikeFloat = float) -> Array:
|
||||
r"""Sample logistic random values with given shape and float dtype.
|
||||
|
||||
The values are distributed according to the probability density function:
|
||||
@ -1574,6 +1589,7 @@ def logistic(key: KeyArray,
|
||||
A random array with the specified shape and dtype.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `logistic` must be a float "
|
||||
f"dtype, got {dtype}")
|
||||
@ -1591,7 +1607,7 @@ def _logistic(key, shape, dtype):
|
||||
def pareto(key: KeyArray,
|
||||
b: RealArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
dtype: DTypeLikeFloat = float) -> Array:
|
||||
r"""Sample Pareto random values with given shape and float dtype.
|
||||
|
||||
The values are distributed according to the probability density function:
|
||||
@ -1616,6 +1632,7 @@ def pareto(key: KeyArray,
|
||||
``shape`` is not None, or else by ``b.shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `pareto` must be a float "
|
||||
f"dtype, got {dtype}")
|
||||
@ -1639,7 +1656,7 @@ def _pareto(key, b, shape, dtype) -> Array:
|
||||
def t(key: KeyArray,
|
||||
df: RealArray,
|
||||
shape: Shape = (),
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
dtype: DTypeLikeFloat = float) -> Array:
|
||||
r"""Sample Student's t random values with given shape and float dtype.
|
||||
|
||||
The values are distributed according to the probability density function:
|
||||
@ -1664,6 +1681,7 @@ def t(key: KeyArray,
|
||||
``shape`` is not None, or else by ``df.shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `t` must be a float "
|
||||
f"dtype, got {dtype}")
|
||||
@ -1690,7 +1708,7 @@ def _t(key, df, shape, dtype) -> Array:
|
||||
def chisquare(key: KeyArray,
|
||||
df: RealArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
dtype: DTypeLikeFloat = float) -> Array:
|
||||
r"""Sample Chisquare random values with given shape and float dtype.
|
||||
|
||||
The values are distributed according to the probability density function:
|
||||
@ -1716,6 +1734,7 @@ def chisquare(key: KeyArray,
|
||||
``shape`` is not None, or else by ``df.shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError("dtype argument to `chisquare` must be a float "
|
||||
f"dtype, got {dtype}")
|
||||
@ -1742,7 +1761,7 @@ def f(key: KeyArray,
|
||||
dfnum: RealArray,
|
||||
dfden: RealArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
dtype: DTypeLikeFloat = float) -> Array:
|
||||
r"""Sample F-distribution random values with given shape and float dtype.
|
||||
|
||||
The values are distributed according to the probability density function:
|
||||
@ -1773,6 +1792,7 @@ def f(key: KeyArray,
|
||||
``shape`` is not None, or else by ``df.shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError("dtype argument to `f` must be a float "
|
||||
f"dtype, got {dtype}")
|
||||
@ -1803,7 +1823,7 @@ def _f(key, dfnum, dfden, shape, dtype) -> Array:
|
||||
|
||||
def rademacher(key: KeyArray,
|
||||
shape: Shape,
|
||||
dtype: DTypeLikeInt = dtypes.int_) -> Array:
|
||||
dtype: DTypeLikeInt = int) -> Array:
|
||||
r"""Sample from a Rademacher distribution.
|
||||
|
||||
The values are distributed according to the probability mass function:
|
||||
@ -1824,6 +1844,7 @@ def rademacher(key: KeyArray,
|
||||
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
shape = core.canonicalize_shape(shape)
|
||||
return _rademacher(key, shape, dtype)
|
||||
@ -1837,7 +1858,7 @@ def _rademacher(key, shape, dtype) -> Array:
|
||||
|
||||
def maxwell(key: KeyArray,
|
||||
shape: Shape = (),
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
dtype: DTypeLikeFloat = float) -> Array:
|
||||
r"""Sample from a one sided Maxwell distribution.
|
||||
|
||||
The values are distributed according to the probability density function:
|
||||
@ -1859,6 +1880,7 @@ def maxwell(key: KeyArray,
|
||||
# Generate samples using:
|
||||
# sqrt(X^2 + Y^2 + Z^2), X,Y,Z ~N(0,1)
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `maxwell` must be a float "
|
||||
f"dtype, got {dtype}")
|
||||
@ -1878,7 +1900,7 @@ def double_sided_maxwell(key: KeyArray,
|
||||
loc: RealArray,
|
||||
scale: RealArray,
|
||||
shape: Shape = (),
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
dtype: DTypeLikeFloat = float) -> Array:
|
||||
r"""Sample from a double sided Maxwell distribution.
|
||||
|
||||
The values are distributed according to the probability density function:
|
||||
@ -1901,6 +1923,7 @@ def double_sided_maxwell(key: KeyArray,
|
||||
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `double_sided_maxwell` must be a float"
|
||||
f" dtype, got {dtype}")
|
||||
@ -1929,7 +1952,7 @@ def weibull_min(key: KeyArray,
|
||||
scale: RealArray,
|
||||
concentration: RealArray,
|
||||
shape: Shape = (),
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
dtype: DTypeLikeFloat = float) -> Array:
|
||||
r"""Sample from a Weibull distribution.
|
||||
|
||||
The values are distributed according to the probability density function:
|
||||
@ -1952,6 +1975,7 @@ def weibull_min(key: KeyArray,
|
||||
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError(f"dtype argument to `weibull_min` must be a float "
|
||||
f"dtype, got {dtype}")
|
||||
@ -1982,7 +2006,7 @@ def orthogonal(
|
||||
key: KeyArray,
|
||||
n: int,
|
||||
shape: Shape = (),
|
||||
dtype: DTypeLikeFloat = dtypes.float_
|
||||
dtype: DTypeLikeFloat = float
|
||||
) -> Array:
|
||||
"""Sample uniformly from the orthogonal group O(n).
|
||||
|
||||
@ -1999,6 +2023,7 @@ def orthogonal(
|
||||
A random array of shape `(*shape, n, n)` and specified dtype.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
_check_shape("orthogonal", shape)
|
||||
n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
|
||||
z = normal(key, (*shape, n, n), dtype)
|
||||
@ -2010,7 +2035,7 @@ def generalized_normal(
|
||||
key: KeyArray,
|
||||
p: float,
|
||||
shape: Shape = (),
|
||||
dtype: DTypeLikeFloat = dtypes.float_
|
||||
dtype: DTypeLikeFloat = float
|
||||
) -> Array:
|
||||
r"""Sample from the generalized normal distribution.
|
||||
|
||||
@ -2033,6 +2058,7 @@ def generalized_normal(
|
||||
A random array with the specified shape and dtype.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
_check_shape("generalized_normal", shape)
|
||||
keys = split(key)
|
||||
g = gamma(keys[0], 1/p, shape, dtype)
|
||||
@ -2044,7 +2070,7 @@ def ball(
|
||||
d: int,
|
||||
p: float = 2,
|
||||
shape: Shape = (),
|
||||
dtype: DTypeLikeFloat = dtypes.float_
|
||||
dtype: DTypeLikeFloat = float
|
||||
):
|
||||
"""Sample uniformly from the unit Lp ball.
|
||||
|
||||
@ -2062,6 +2088,7 @@ def ball(
|
||||
A random array of shape `(*shape, d)` and specified dtype.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
_check_shape("ball", shape)
|
||||
d = core.concrete_or_error(index, d, "The error occurred in jax.random.ball()")
|
||||
k1, k2 = split(key)
|
||||
@ -2073,7 +2100,7 @@ def ball(
|
||||
def rayleigh(key: KeyArray,
|
||||
scale: RealArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
dtype: DTypeLikeFloat = float) -> Array:
|
||||
r"""Sample Rayleigh random values with given shape and float dtype.
|
||||
|
||||
The values are returned according to the probability density function:
|
||||
@ -2099,6 +2126,7 @@ def rayleigh(key: KeyArray,
|
||||
``shape`` is not None, or else by ``scale.shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError("dtype argument to `rayleigh` must be a float "
|
||||
f"dtype, got {dtype}")
|
||||
@ -2125,7 +2153,7 @@ def _rayleigh(key, scale, shape, dtype) -> Array:
|
||||
def wald(key: KeyArray,
|
||||
mean: RealArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
dtype: DTypeLikeFloat = float) -> Array:
|
||||
r"""Sample Wald random values with given shape and float dtype.
|
||||
|
||||
The values are returned according to the probability density function:
|
||||
@ -2152,6 +2180,7 @@ def wald(key: KeyArray,
|
||||
``shape`` is not None, or else by ``mean.shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError("dtype argument to `wald` must be a float "
|
||||
f"dtype, got {dtype}")
|
||||
@ -2182,7 +2211,7 @@ def _wald(key, mean, shape, dtype) -> Array:
|
||||
def geometric(key: KeyArray,
|
||||
p: RealArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeInt = dtypes.int_) -> Array:
|
||||
dtype: DTypeLikeInt = int) -> Array:
|
||||
r"""Sample Geometric random values with given shape and float dtype.
|
||||
|
||||
The values are returned according to the probability mass function:
|
||||
@ -2207,6 +2236,7 @@ def geometric(key: KeyArray,
|
||||
``shape`` is not None, or else by ``p.shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.integer):
|
||||
raise ValueError("dtype argument to `geometric` must be an int "
|
||||
f"dtype, got {dtype}")
|
||||
@ -2236,7 +2266,7 @@ def triangular(key: KeyArray,
|
||||
mode: RealArray,
|
||||
right: RealArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
dtype: DTypeLikeFloat = float) -> Array:
|
||||
r"""Sample Triangular random values with given shape and float dtype.
|
||||
|
||||
The values are returned according to the probability density function:
|
||||
@ -2299,7 +2329,7 @@ def _triangular(key, left, mode, right, shape, dtype) -> Array:
|
||||
def lognormal(key: KeyArray,
|
||||
sigma: RealArray = np.float32(1),
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
dtype: DTypeLikeFloat = float) -> Array:
|
||||
r""" Sample lognormal random values with given shape and float dtype.
|
||||
|
||||
The values are distributed according to the probability density function:
|
||||
|
@ -1085,7 +1085,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(
|
||||
lam=[0.5, 3, 9, 11, 50, 500],
|
||||
dtype=[np.int16, np.int32, np.int64],
|
||||
dtype=jtu.dtypes.supported([np.int16, np.int32, np.int64]),
|
||||
)
|
||||
def testPoisson(self, lam, dtype):
|
||||
key = self.make_key(0)
|
||||
@ -1662,8 +1662,8 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.invgauss(mu=mean).cdf)
|
||||
|
||||
@jtu.sample_product(
|
||||
p= [0.2, 0.3, 0.4, 0.5 ,0.6],
|
||||
dtype= [np.int16, np.int32, np.int64])
|
||||
p=[0.2, 0.3, 0.4, 0.5 ,0.6],
|
||||
dtype=jtu.dtypes.supported([np.int16, np.int32, np.int64]))
|
||||
def testGeometric(self, p, dtype):
|
||||
key = self.make_key(1)
|
||||
rand = lambda key: random.geometric(key, p, shape=(10000, ), dtype=dtype)
|
||||
|
@ -112,6 +112,8 @@ class X64ContextTests(jtu.JaxTestCase):
|
||||
self.assertEqual(x32.result(), jnp.int32)
|
||||
|
||||
@jax.legacy_prng_key('allow')
|
||||
@jtu.ignore_warning(category=UserWarning,
|
||||
message="Explicitly requested dtype float64 is not available")
|
||||
def test_jit_cache(self):
|
||||
if jtu.device_under_test() == "tpu":
|
||||
self.skipTest("64-bit random not available on TPU")
|
||||
|
Loading…
x
Reference in New Issue
Block a user