jax.random: warn on unsupported dtypes

This commit is contained in:
Jake VanderPlas 2023-08-31 10:56:05 -07:00
parent faa7a68422
commit f0309b49c9
3 changed files with 65 additions and 33 deletions

View File

@ -325,7 +325,7 @@ def bits(key: KeyArray,
def uniform(key: KeyArray,
shape: Union[Shape, NamedShape] = (),
dtype: DTypeLikeFloat = dtypes.float_,
dtype: DTypeLikeFloat = float,
minval: RealArray = 0.,
maxval: RealArray = 1.) -> Array:
"""Sample uniform random values in [minval, maxval) with given shape/dtype.
@ -343,6 +343,8 @@ def uniform(key: KeyArray,
A random array with the specified shape and dtype.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `uniform` must be a float dtype, "
f"got {dtype}")
@ -393,7 +395,7 @@ def randint(key: KeyArray,
shape: Shape,
minval: IntegerArray,
maxval: IntegerArray,
dtype: DTypeLikeInt = dtypes.int_) -> Array:
dtype: DTypeLikeInt = int) -> Array:
"""Sample uniform random values in [minval, maxval) with given shape/dtype.
Args:
@ -410,6 +412,7 @@ def randint(key: KeyArray,
A random array with the specified shape and dtype.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
dtype = dtypes.canonicalize_dtype(dtype)
shape = core.canonicalize_shape(shape)
return _randint(key, shape, minval, maxval, dtype)
@ -633,7 +636,7 @@ def choice(key: KeyArray,
def normal(key: KeyArray,
shape: Union[Shape, NamedShape] = (),
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample standard normal random values with given shape and float dtype.
The values are returned according to the probability density function:
@ -654,6 +657,7 @@ def normal(key: KeyArray,
A random array with the specified shape and dtype.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.inexact):
raise ValueError(f"dtype argument to `normal` must be a float or complex dtype, "
f"got {dtype}")
@ -720,6 +724,7 @@ def multivariate_normal(key: KeyArray,
``broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]``.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
mean, cov = promote_dtypes_inexact(mean, cov)
if method not in {'svd', 'eigh', 'cholesky'}:
raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}")
@ -769,7 +774,7 @@ def truncated_normal(key: KeyArray,
lower: RealArray,
upper: RealArray,
shape: Optional[Union[Shape, NamedShape]] = None,
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample truncated standard normal random values with given shape and dtype.
The values are returned according to the probability density function:
@ -798,6 +803,7 @@ def truncated_normal(key: KeyArray,
Returns values in the open interval ``(lower, upper)``.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `truncated_normal` must be a float "
f"dtype, got {dtype}")
@ -879,7 +885,7 @@ def beta(key: KeyArray,
a: RealArray,
b: RealArray,
shape: Optional[Shape] = None,
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Beta random values with given shape and float dtype.
The values are distributed according to the probability density function:
@ -906,6 +912,7 @@ def beta(key: KeyArray,
``shape`` is not None, or else by broadcasting ``a`` and ``b``.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `beta` must be a float "
f"dtype, got {dtype}")
@ -937,7 +944,7 @@ def _beta(key, a, b, shape, dtype) -> Array:
def cauchy(key: KeyArray,
shape: Shape = (),
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Cauchy random values with given shape and float dtype.
The values are distributed according to the probability density function:
@ -958,6 +965,7 @@ def cauchy(key: KeyArray,
A random array with the specified shape and dtype.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `cauchy` must be a float "
f"dtype, got {dtype}")
@ -976,7 +984,7 @@ def _cauchy(key, shape, dtype) -> Array:
def dirichlet(key: KeyArray,
alpha: RealArray,
shape: Optional[Shape] = None,
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Dirichlet random values with given shape and float dtype.
The values are distributed according the the probability density function:
@ -1009,6 +1017,7 @@ def dirichlet(key: KeyArray,
``alpha.shape``.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `dirichlet` must be a float "
f"dtype, got {dtype}")
@ -1046,7 +1055,7 @@ def _softmax(x, axis) -> Array:
def exponential(key: KeyArray,
shape: Shape = (),
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Exponential random values with given shape and float dtype.
The values are distributed according the the probability density function:
@ -1067,6 +1076,7 @@ def exponential(key: KeyArray,
A random array with the specified shape and dtype.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `exponential` must be a float "
f"dtype, got {dtype}")
@ -1213,7 +1223,7 @@ batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule
def gamma(key: KeyArray,
a: RealArray,
shape: Optional[Shape] = None,
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Gamma random values with given shape and float dtype.
The values are distributed according the the probability density function:
@ -1247,6 +1257,7 @@ def gamma(key: KeyArray,
accuracy for small values of ``a``.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `gamma` must be a float "
f"dtype, got {dtype}")
@ -1259,7 +1270,7 @@ def gamma(key: KeyArray,
def loggamma(key: KeyArray,
a: RealArray,
shape: Optional[Shape] = None,
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
dtype: DTypeLikeFloat = float) -> Array:
"""Sample log-gamma random values with given shape and float dtype.
This function is implemented such that the following will hold for a
@ -1288,6 +1299,7 @@ def loggamma(key: KeyArray,
gamma : standard gamma sampler.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `gamma` must be a float "
f"dtype, got {dtype}")
@ -1400,7 +1412,7 @@ def _poisson(key, lam, shape, dtype) -> Array:
def poisson(key: KeyArray,
lam: RealArray,
shape: Optional[Shape] = None,
dtype: DTypeLikeInt = dtypes.int_) -> Array:
dtype: DTypeLikeInt = int) -> Array:
r"""Sample Poisson random values with given shape and integer dtype.
The values are distributed according to the probability mass function:
@ -1423,6 +1435,7 @@ def poisson(key: KeyArray,
``shape is not None, or else by ``lam.shape``.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
# TODO(frostig): generalize underlying poisson implementation and
# remove this check
key_impl = key.dtype.impl # type: ignore[union-attr]
@ -1442,7 +1455,7 @@ def poisson(key: KeyArray,
def gumbel(key: KeyArray,
shape: Shape = (),
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
dtype: DTypeLikeFloat = float) -> Array:
"""Sample Gumbel random values with given shape and float dtype.
The values are distributed according to the probability density function:
@ -1461,6 +1474,7 @@ def gumbel(key: KeyArray,
A random array with the specified shape and dtype.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `gumbel` must be a float "
f"dtype, got {dtype}")
@ -1519,7 +1533,7 @@ def categorical(key: KeyArray,
def laplace(key: KeyArray,
shape: Shape = (),
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Laplace random values with given shape and float dtype.
The values are distributed according to the probability density function:
@ -1538,6 +1552,7 @@ def laplace(key: KeyArray,
A random array with the specified shape and dtype.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `laplace` must be a float "
f"dtype, got {dtype}")
@ -1555,7 +1570,7 @@ def _laplace(key, shape, dtype) -> Array:
def logistic(key: KeyArray,
shape: Shape = (),
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample logistic random values with given shape and float dtype.
The values are distributed according to the probability density function:
@ -1574,6 +1589,7 @@ def logistic(key: KeyArray,
A random array with the specified shape and dtype.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `logistic` must be a float "
f"dtype, got {dtype}")
@ -1591,7 +1607,7 @@ def _logistic(key, shape, dtype):
def pareto(key: KeyArray,
b: RealArray,
shape: Optional[Shape] = None,
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Pareto random values with given shape and float dtype.
The values are distributed according to the probability density function:
@ -1616,6 +1632,7 @@ def pareto(key: KeyArray,
``shape`` is not None, or else by ``b.shape``.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `pareto` must be a float "
f"dtype, got {dtype}")
@ -1639,7 +1656,7 @@ def _pareto(key, b, shape, dtype) -> Array:
def t(key: KeyArray,
df: RealArray,
shape: Shape = (),
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Student's t random values with given shape and float dtype.
The values are distributed according to the probability density function:
@ -1664,6 +1681,7 @@ def t(key: KeyArray,
``shape`` is not None, or else by ``df.shape``.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `t` must be a float "
f"dtype, got {dtype}")
@ -1690,7 +1708,7 @@ def _t(key, df, shape, dtype) -> Array:
def chisquare(key: KeyArray,
df: RealArray,
shape: Optional[Shape] = None,
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Chisquare random values with given shape and float dtype.
The values are distributed according to the probability density function:
@ -1716,6 +1734,7 @@ def chisquare(key: KeyArray,
``shape`` is not None, or else by ``df.shape``.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError("dtype argument to `chisquare` must be a float "
f"dtype, got {dtype}")
@ -1742,7 +1761,7 @@ def f(key: KeyArray,
dfnum: RealArray,
dfden: RealArray,
shape: Optional[Shape] = None,
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample F-distribution random values with given shape and float dtype.
The values are distributed according to the probability density function:
@ -1773,6 +1792,7 @@ def f(key: KeyArray,
``shape`` is not None, or else by ``df.shape``.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError("dtype argument to `f` must be a float "
f"dtype, got {dtype}")
@ -1803,7 +1823,7 @@ def _f(key, dfnum, dfden, shape, dtype) -> Array:
def rademacher(key: KeyArray,
shape: Shape,
dtype: DTypeLikeInt = dtypes.int_) -> Array:
dtype: DTypeLikeInt = int) -> Array:
r"""Sample from a Rademacher distribution.
The values are distributed according to the probability mass function:
@ -1824,6 +1844,7 @@ def rademacher(key: KeyArray,
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
dtype = dtypes.canonicalize_dtype(dtype)
shape = core.canonicalize_shape(shape)
return _rademacher(key, shape, dtype)
@ -1837,7 +1858,7 @@ def _rademacher(key, shape, dtype) -> Array:
def maxwell(key: KeyArray,
shape: Shape = (),
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample from a one sided Maxwell distribution.
The values are distributed according to the probability density function:
@ -1859,6 +1880,7 @@ def maxwell(key: KeyArray,
# Generate samples using:
# sqrt(X^2 + Y^2 + Z^2), X,Y,Z ~N(0,1)
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `maxwell` must be a float "
f"dtype, got {dtype}")
@ -1878,7 +1900,7 @@ def double_sided_maxwell(key: KeyArray,
loc: RealArray,
scale: RealArray,
shape: Shape = (),
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample from a double sided Maxwell distribution.
The values are distributed according to the probability density function:
@ -1901,6 +1923,7 @@ def double_sided_maxwell(key: KeyArray,
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `double_sided_maxwell` must be a float"
f" dtype, got {dtype}")
@ -1929,7 +1952,7 @@ def weibull_min(key: KeyArray,
scale: RealArray,
concentration: RealArray,
shape: Shape = (),
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample from a Weibull distribution.
The values are distributed according to the probability density function:
@ -1952,6 +1975,7 @@ def weibull_min(key: KeyArray,
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `weibull_min` must be a float "
f"dtype, got {dtype}")
@ -1982,7 +2006,7 @@ def orthogonal(
key: KeyArray,
n: int,
shape: Shape = (),
dtype: DTypeLikeFloat = dtypes.float_
dtype: DTypeLikeFloat = float
) -> Array:
"""Sample uniformly from the orthogonal group O(n).
@ -1999,6 +2023,7 @@ def orthogonal(
A random array of shape `(*shape, n, n)` and specified dtype.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
_check_shape("orthogonal", shape)
n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
z = normal(key, (*shape, n, n), dtype)
@ -2010,7 +2035,7 @@ def generalized_normal(
key: KeyArray,
p: float,
shape: Shape = (),
dtype: DTypeLikeFloat = dtypes.float_
dtype: DTypeLikeFloat = float
) -> Array:
r"""Sample from the generalized normal distribution.
@ -2033,6 +2058,7 @@ def generalized_normal(
A random array with the specified shape and dtype.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
_check_shape("generalized_normal", shape)
keys = split(key)
g = gamma(keys[0], 1/p, shape, dtype)
@ -2044,7 +2070,7 @@ def ball(
d: int,
p: float = 2,
shape: Shape = (),
dtype: DTypeLikeFloat = dtypes.float_
dtype: DTypeLikeFloat = float
):
"""Sample uniformly from the unit Lp ball.
@ -2062,6 +2088,7 @@ def ball(
A random array of shape `(*shape, d)` and specified dtype.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
_check_shape("ball", shape)
d = core.concrete_or_error(index, d, "The error occurred in jax.random.ball()")
k1, k2 = split(key)
@ -2073,7 +2100,7 @@ def ball(
def rayleigh(key: KeyArray,
scale: RealArray,
shape: Optional[Shape] = None,
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Rayleigh random values with given shape and float dtype.
The values are returned according to the probability density function:
@ -2099,6 +2126,7 @@ def rayleigh(key: KeyArray,
``shape`` is not None, or else by ``scale.shape``.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError("dtype argument to `rayleigh` must be a float "
f"dtype, got {dtype}")
@ -2125,7 +2153,7 @@ def _rayleigh(key, scale, shape, dtype) -> Array:
def wald(key: KeyArray,
mean: RealArray,
shape: Optional[Shape] = None,
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Wald random values with given shape and float dtype.
The values are returned according to the probability density function:
@ -2152,6 +2180,7 @@ def wald(key: KeyArray,
``shape`` is not None, or else by ``mean.shape``.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError("dtype argument to `wald` must be a float "
f"dtype, got {dtype}")
@ -2182,7 +2211,7 @@ def _wald(key, mean, shape, dtype) -> Array:
def geometric(key: KeyArray,
p: RealArray,
shape: Optional[Shape] = None,
dtype: DTypeLikeInt = dtypes.int_) -> Array:
dtype: DTypeLikeInt = int) -> Array:
r"""Sample Geometric random values with given shape and float dtype.
The values are returned according to the probability mass function:
@ -2207,6 +2236,7 @@ def geometric(key: KeyArray,
``shape`` is not None, or else by ``p.shape``.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.integer):
raise ValueError("dtype argument to `geometric` must be an int "
f"dtype, got {dtype}")
@ -2236,7 +2266,7 @@ def triangular(key: KeyArray,
mode: RealArray,
right: RealArray,
shape: Optional[Shape] = None,
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Triangular random values with given shape and float dtype.
The values are returned according to the probability density function:
@ -2299,7 +2329,7 @@ def _triangular(key, left, mode, right, shape, dtype) -> Array:
def lognormal(key: KeyArray,
sigma: RealArray = np.float32(1),
shape: Optional[Shape] = None,
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
dtype: DTypeLikeFloat = float) -> Array:
r""" Sample lognormal random values with given shape and float dtype.
The values are distributed according to the probability density function:

View File

@ -1085,7 +1085,7 @@ class LaxRandomTest(jtu.JaxTestCase):
@jtu.sample_product(
lam=[0.5, 3, 9, 11, 50, 500],
dtype=[np.int16, np.int32, np.int64],
dtype=jtu.dtypes.supported([np.int16, np.int32, np.int64]),
)
def testPoisson(self, lam, dtype):
key = self.make_key(0)
@ -1662,8 +1662,8 @@ class LaxRandomTest(jtu.JaxTestCase):
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.invgauss(mu=mean).cdf)
@jtu.sample_product(
p= [0.2, 0.3, 0.4, 0.5 ,0.6],
dtype= [np.int16, np.int32, np.int64])
p=[0.2, 0.3, 0.4, 0.5 ,0.6],
dtype=jtu.dtypes.supported([np.int16, np.int32, np.int64]))
def testGeometric(self, p, dtype):
key = self.make_key(1)
rand = lambda key: random.geometric(key, p, shape=(10000, ), dtype=dtype)

View File

@ -112,6 +112,8 @@ class X64ContextTests(jtu.JaxTestCase):
self.assertEqual(x32.result(), jnp.int32)
@jax.legacy_prng_key('allow')
@jtu.ignore_warning(category=UserWarning,
message="Explicitly requested dtype float64 is not available")
def test_jit_cache(self):
if jtu.device_under_test() == "tpu":
self.skipTest("64-bit random not available on TPU")