From f0309b49c92d2f3a69b26e432e77a9a1c124a8e0 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 31 Aug 2023 10:56:05 -0700 Subject: [PATCH] jax.random: warn on unsupported dtypes --- jax/_src/random.py | 90 ++++++++++++++++++++++++++------------- tests/random_test.py | 6 +-- tests/x64_context_test.py | 2 + 3 files changed, 65 insertions(+), 33 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index 5993ddd94..431fc3072 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -325,7 +325,7 @@ def bits(key: KeyArray, def uniform(key: KeyArray, shape: Union[Shape, NamedShape] = (), - dtype: DTypeLikeFloat = dtypes.float_, + dtype: DTypeLikeFloat = float, minval: RealArray = 0., maxval: RealArray = 1.) -> Array: """Sample uniform random values in [minval, maxval) with given shape/dtype. @@ -343,6 +343,8 @@ def uniform(key: KeyArray, A random array with the specified shape and dtype. """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) + if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `uniform` must be a float dtype, " f"got {dtype}") @@ -393,7 +395,7 @@ def randint(key: KeyArray, shape: Shape, minval: IntegerArray, maxval: IntegerArray, - dtype: DTypeLikeInt = dtypes.int_) -> Array: + dtype: DTypeLikeInt = int) -> Array: """Sample uniform random values in [minval, maxval) with given shape/dtype. Args: @@ -410,6 +412,7 @@ def randint(key: KeyArray, A random array with the specified shape and dtype. """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) dtype = dtypes.canonicalize_dtype(dtype) shape = core.canonicalize_shape(shape) return _randint(key, shape, minval, maxval, dtype) @@ -633,7 +636,7 @@ def choice(key: KeyArray, def normal(key: KeyArray, shape: Union[Shape, NamedShape] = (), - dtype: DTypeLikeFloat = dtypes.float_) -> Array: + dtype: DTypeLikeFloat = float) -> Array: r"""Sample standard normal random values with given shape and float dtype. The values are returned according to the probability density function: @@ -654,6 +657,7 @@ def normal(key: KeyArray, A random array with the specified shape and dtype. """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.inexact): raise ValueError(f"dtype argument to `normal` must be a float or complex dtype, " f"got {dtype}") @@ -720,6 +724,7 @@ def multivariate_normal(key: KeyArray, ``broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]``. """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) mean, cov = promote_dtypes_inexact(mean, cov) if method not in {'svd', 'eigh', 'cholesky'}: raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}") @@ -769,7 +774,7 @@ def truncated_normal(key: KeyArray, lower: RealArray, upper: RealArray, shape: Optional[Union[Shape, NamedShape]] = None, - dtype: DTypeLikeFloat = dtypes.float_) -> Array: + dtype: DTypeLikeFloat = float) -> Array: r"""Sample truncated standard normal random values with given shape and dtype. The values are returned according to the probability density function: @@ -798,6 +803,7 @@ def truncated_normal(key: KeyArray, Returns values in the open interval ``(lower, upper)``. """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `truncated_normal` must be a float " f"dtype, got {dtype}") @@ -879,7 +885,7 @@ def beta(key: KeyArray, a: RealArray, b: RealArray, shape: Optional[Shape] = None, - dtype: DTypeLikeFloat = dtypes.float_) -> Array: + dtype: DTypeLikeFloat = float) -> Array: r"""Sample Beta random values with given shape and float dtype. The values are distributed according to the probability density function: @@ -906,6 +912,7 @@ def beta(key: KeyArray, ``shape`` is not None, or else by broadcasting ``a`` and ``b``. """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `beta` must be a float " f"dtype, got {dtype}") @@ -937,7 +944,7 @@ def _beta(key, a, b, shape, dtype) -> Array: def cauchy(key: KeyArray, shape: Shape = (), - dtype: DTypeLikeFloat = dtypes.float_) -> Array: + dtype: DTypeLikeFloat = float) -> Array: r"""Sample Cauchy random values with given shape and float dtype. The values are distributed according to the probability density function: @@ -958,6 +965,7 @@ def cauchy(key: KeyArray, A random array with the specified shape and dtype. """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `cauchy` must be a float " f"dtype, got {dtype}") @@ -976,7 +984,7 @@ def _cauchy(key, shape, dtype) -> Array: def dirichlet(key: KeyArray, alpha: RealArray, shape: Optional[Shape] = None, - dtype: DTypeLikeFloat = dtypes.float_) -> Array: + dtype: DTypeLikeFloat = float) -> Array: r"""Sample Dirichlet random values with given shape and float dtype. The values are distributed according the the probability density function: @@ -1009,6 +1017,7 @@ def dirichlet(key: KeyArray, ``alpha.shape``. """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `dirichlet` must be a float " f"dtype, got {dtype}") @@ -1046,7 +1055,7 @@ def _softmax(x, axis) -> Array: def exponential(key: KeyArray, shape: Shape = (), - dtype: DTypeLikeFloat = dtypes.float_) -> Array: + dtype: DTypeLikeFloat = float) -> Array: r"""Sample Exponential random values with given shape and float dtype. The values are distributed according the the probability density function: @@ -1067,6 +1076,7 @@ def exponential(key: KeyArray, A random array with the specified shape and dtype. """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `exponential` must be a float " f"dtype, got {dtype}") @@ -1213,7 +1223,7 @@ batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule def gamma(key: KeyArray, a: RealArray, shape: Optional[Shape] = None, - dtype: DTypeLikeFloat = dtypes.float_) -> Array: + dtype: DTypeLikeFloat = float) -> Array: r"""Sample Gamma random values with given shape and float dtype. The values are distributed according the the probability density function: @@ -1247,6 +1257,7 @@ def gamma(key: KeyArray, accuracy for small values of ``a``. """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `gamma` must be a float " f"dtype, got {dtype}") @@ -1259,7 +1270,7 @@ def gamma(key: KeyArray, def loggamma(key: KeyArray, a: RealArray, shape: Optional[Shape] = None, - dtype: DTypeLikeFloat = dtypes.float_) -> Array: + dtype: DTypeLikeFloat = float) -> Array: """Sample log-gamma random values with given shape and float dtype. This function is implemented such that the following will hold for a @@ -1288,6 +1299,7 @@ def loggamma(key: KeyArray, gamma : standard gamma sampler. """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `gamma` must be a float " f"dtype, got {dtype}") @@ -1400,7 +1412,7 @@ def _poisson(key, lam, shape, dtype) -> Array: def poisson(key: KeyArray, lam: RealArray, shape: Optional[Shape] = None, - dtype: DTypeLikeInt = dtypes.int_) -> Array: + dtype: DTypeLikeInt = int) -> Array: r"""Sample Poisson random values with given shape and integer dtype. The values are distributed according to the probability mass function: @@ -1423,6 +1435,7 @@ def poisson(key: KeyArray, ``shape is not None, or else by ``lam.shape``. """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) # TODO(frostig): generalize underlying poisson implementation and # remove this check key_impl = key.dtype.impl # type: ignore[union-attr] @@ -1442,7 +1455,7 @@ def poisson(key: KeyArray, def gumbel(key: KeyArray, shape: Shape = (), - dtype: DTypeLikeFloat = dtypes.float_) -> Array: + dtype: DTypeLikeFloat = float) -> Array: """Sample Gumbel random values with given shape and float dtype. The values are distributed according to the probability density function: @@ -1461,6 +1474,7 @@ def gumbel(key: KeyArray, A random array with the specified shape and dtype. """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `gumbel` must be a float " f"dtype, got {dtype}") @@ -1519,7 +1533,7 @@ def categorical(key: KeyArray, def laplace(key: KeyArray, shape: Shape = (), - dtype: DTypeLikeFloat = dtypes.float_) -> Array: + dtype: DTypeLikeFloat = float) -> Array: r"""Sample Laplace random values with given shape and float dtype. The values are distributed according to the probability density function: @@ -1538,6 +1552,7 @@ def laplace(key: KeyArray, A random array with the specified shape and dtype. """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `laplace` must be a float " f"dtype, got {dtype}") @@ -1555,7 +1570,7 @@ def _laplace(key, shape, dtype) -> Array: def logistic(key: KeyArray, shape: Shape = (), - dtype: DTypeLikeFloat = dtypes.float_) -> Array: + dtype: DTypeLikeFloat = float) -> Array: r"""Sample logistic random values with given shape and float dtype. The values are distributed according to the probability density function: @@ -1574,6 +1589,7 @@ def logistic(key: KeyArray, A random array with the specified shape and dtype. """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `logistic` must be a float " f"dtype, got {dtype}") @@ -1591,7 +1607,7 @@ def _logistic(key, shape, dtype): def pareto(key: KeyArray, b: RealArray, shape: Optional[Shape] = None, - dtype: DTypeLikeFloat = dtypes.float_) -> Array: + dtype: DTypeLikeFloat = float) -> Array: r"""Sample Pareto random values with given shape and float dtype. The values are distributed according to the probability density function: @@ -1616,6 +1632,7 @@ def pareto(key: KeyArray, ``shape`` is not None, or else by ``b.shape``. """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `pareto` must be a float " f"dtype, got {dtype}") @@ -1639,7 +1656,7 @@ def _pareto(key, b, shape, dtype) -> Array: def t(key: KeyArray, df: RealArray, shape: Shape = (), - dtype: DTypeLikeFloat = dtypes.float_) -> Array: + dtype: DTypeLikeFloat = float) -> Array: r"""Sample Student's t random values with given shape and float dtype. The values are distributed according to the probability density function: @@ -1664,6 +1681,7 @@ def t(key: KeyArray, ``shape`` is not None, or else by ``df.shape``. """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `t` must be a float " f"dtype, got {dtype}") @@ -1690,7 +1708,7 @@ def _t(key, df, shape, dtype) -> Array: def chisquare(key: KeyArray, df: RealArray, shape: Optional[Shape] = None, - dtype: DTypeLikeFloat = dtypes.float_) -> Array: + dtype: DTypeLikeFloat = float) -> Array: r"""Sample Chisquare random values with given shape and float dtype. The values are distributed according to the probability density function: @@ -1716,6 +1734,7 @@ def chisquare(key: KeyArray, ``shape`` is not None, or else by ``df.shape``. """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError("dtype argument to `chisquare` must be a float " f"dtype, got {dtype}") @@ -1742,7 +1761,7 @@ def f(key: KeyArray, dfnum: RealArray, dfden: RealArray, shape: Optional[Shape] = None, - dtype: DTypeLikeFloat = dtypes.float_) -> Array: + dtype: DTypeLikeFloat = float) -> Array: r"""Sample F-distribution random values with given shape and float dtype. The values are distributed according to the probability density function: @@ -1773,6 +1792,7 @@ def f(key: KeyArray, ``shape`` is not None, or else by ``df.shape``. """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError("dtype argument to `f` must be a float " f"dtype, got {dtype}") @@ -1803,7 +1823,7 @@ def _f(key, dfnum, dfden, shape, dtype) -> Array: def rademacher(key: KeyArray, shape: Shape, - dtype: DTypeLikeInt = dtypes.int_) -> Array: + dtype: DTypeLikeInt = int) -> Array: r"""Sample from a Rademacher distribution. The values are distributed according to the probability mass function: @@ -1824,6 +1844,7 @@ def rademacher(key: KeyArray, """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) dtype = dtypes.canonicalize_dtype(dtype) shape = core.canonicalize_shape(shape) return _rademacher(key, shape, dtype) @@ -1837,7 +1858,7 @@ def _rademacher(key, shape, dtype) -> Array: def maxwell(key: KeyArray, shape: Shape = (), - dtype: DTypeLikeFloat = dtypes.float_) -> Array: + dtype: DTypeLikeFloat = float) -> Array: r"""Sample from a one sided Maxwell distribution. The values are distributed according to the probability density function: @@ -1859,6 +1880,7 @@ def maxwell(key: KeyArray, # Generate samples using: # sqrt(X^2 + Y^2 + Z^2), X,Y,Z ~N(0,1) key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `maxwell` must be a float " f"dtype, got {dtype}") @@ -1878,7 +1900,7 @@ def double_sided_maxwell(key: KeyArray, loc: RealArray, scale: RealArray, shape: Shape = (), - dtype: DTypeLikeFloat = dtypes.float_) -> Array: + dtype: DTypeLikeFloat = float) -> Array: r"""Sample from a double sided Maxwell distribution. The values are distributed according to the probability density function: @@ -1901,6 +1923,7 @@ def double_sided_maxwell(key: KeyArray, """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `double_sided_maxwell` must be a float" f" dtype, got {dtype}") @@ -1929,7 +1952,7 @@ def weibull_min(key: KeyArray, scale: RealArray, concentration: RealArray, shape: Shape = (), - dtype: DTypeLikeFloat = dtypes.float_) -> Array: + dtype: DTypeLikeFloat = float) -> Array: r"""Sample from a Weibull distribution. The values are distributed according to the probability density function: @@ -1952,6 +1975,7 @@ def weibull_min(key: KeyArray, """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `weibull_min` must be a float " f"dtype, got {dtype}") @@ -1982,7 +2006,7 @@ def orthogonal( key: KeyArray, n: int, shape: Shape = (), - dtype: DTypeLikeFloat = dtypes.float_ + dtype: DTypeLikeFloat = float ) -> Array: """Sample uniformly from the orthogonal group O(n). @@ -1999,6 +2023,7 @@ def orthogonal( A random array of shape `(*shape, n, n)` and specified dtype. """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) _check_shape("orthogonal", shape) n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()") z = normal(key, (*shape, n, n), dtype) @@ -2010,7 +2035,7 @@ def generalized_normal( key: KeyArray, p: float, shape: Shape = (), - dtype: DTypeLikeFloat = dtypes.float_ + dtype: DTypeLikeFloat = float ) -> Array: r"""Sample from the generalized normal distribution. @@ -2033,6 +2058,7 @@ def generalized_normal( A random array with the specified shape and dtype. """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) _check_shape("generalized_normal", shape) keys = split(key) g = gamma(keys[0], 1/p, shape, dtype) @@ -2044,7 +2070,7 @@ def ball( d: int, p: float = 2, shape: Shape = (), - dtype: DTypeLikeFloat = dtypes.float_ + dtype: DTypeLikeFloat = float ): """Sample uniformly from the unit Lp ball. @@ -2062,6 +2088,7 @@ def ball( A random array of shape `(*shape, d)` and specified dtype. """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) _check_shape("ball", shape) d = core.concrete_or_error(index, d, "The error occurred in jax.random.ball()") k1, k2 = split(key) @@ -2073,7 +2100,7 @@ def ball( def rayleigh(key: KeyArray, scale: RealArray, shape: Optional[Shape] = None, - dtype: DTypeLikeFloat = dtypes.float_) -> Array: + dtype: DTypeLikeFloat = float) -> Array: r"""Sample Rayleigh random values with given shape and float dtype. The values are returned according to the probability density function: @@ -2099,6 +2126,7 @@ def rayleigh(key: KeyArray, ``shape`` is not None, or else by ``scale.shape``. """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError("dtype argument to `rayleigh` must be a float " f"dtype, got {dtype}") @@ -2125,7 +2153,7 @@ def _rayleigh(key, scale, shape, dtype) -> Array: def wald(key: KeyArray, mean: RealArray, shape: Optional[Shape] = None, - dtype: DTypeLikeFloat = dtypes.float_) -> Array: + dtype: DTypeLikeFloat = float) -> Array: r"""Sample Wald random values with given shape and float dtype. The values are returned according to the probability density function: @@ -2152,6 +2180,7 @@ def wald(key: KeyArray, ``shape`` is not None, or else by ``mean.shape``. """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError("dtype argument to `wald` must be a float " f"dtype, got {dtype}") @@ -2182,7 +2211,7 @@ def _wald(key, mean, shape, dtype) -> Array: def geometric(key: KeyArray, p: RealArray, shape: Optional[Shape] = None, - dtype: DTypeLikeInt = dtypes.int_) -> Array: + dtype: DTypeLikeInt = int) -> Array: r"""Sample Geometric random values with given shape and float dtype. The values are returned according to the probability mass function: @@ -2207,6 +2236,7 @@ def geometric(key: KeyArray, ``shape`` is not None, or else by ``p.shape``. """ key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.integer): raise ValueError("dtype argument to `geometric` must be an int " f"dtype, got {dtype}") @@ -2236,7 +2266,7 @@ def triangular(key: KeyArray, mode: RealArray, right: RealArray, shape: Optional[Shape] = None, - dtype: DTypeLikeFloat = dtypes.float_) -> Array: + dtype: DTypeLikeFloat = float) -> Array: r"""Sample Triangular random values with given shape and float dtype. The values are returned according to the probability density function: @@ -2299,7 +2329,7 @@ def _triangular(key, left, mode, right, shape, dtype) -> Array: def lognormal(key: KeyArray, sigma: RealArray = np.float32(1), shape: Optional[Shape] = None, - dtype: DTypeLikeFloat = dtypes.float_) -> Array: + dtype: DTypeLikeFloat = float) -> Array: r""" Sample lognormal random values with given shape and float dtype. The values are distributed according to the probability density function: diff --git a/tests/random_test.py b/tests/random_test.py index 22191e531..de8daae5f 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -1085,7 +1085,7 @@ class LaxRandomTest(jtu.JaxTestCase): @jtu.sample_product( lam=[0.5, 3, 9, 11, 50, 500], - dtype=[np.int16, np.int32, np.int64], + dtype=jtu.dtypes.supported([np.int16, np.int32, np.int64]), ) def testPoisson(self, lam, dtype): key = self.make_key(0) @@ -1662,8 +1662,8 @@ class LaxRandomTest(jtu.JaxTestCase): self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.invgauss(mu=mean).cdf) @jtu.sample_product( - p= [0.2, 0.3, 0.4, 0.5 ,0.6], - dtype= [np.int16, np.int32, np.int64]) + p=[0.2, 0.3, 0.4, 0.5 ,0.6], + dtype=jtu.dtypes.supported([np.int16, np.int32, np.int64])) def testGeometric(self, p, dtype): key = self.make_key(1) rand = lambda key: random.geometric(key, p, shape=(10000, ), dtype=dtype) diff --git a/tests/x64_context_test.py b/tests/x64_context_test.py index b916379bc..ea16d202d 100644 --- a/tests/x64_context_test.py +++ b/tests/x64_context_test.py @@ -112,6 +112,8 @@ class X64ContextTests(jtu.JaxTestCase): self.assertEqual(x32.result(), jnp.int32) @jax.legacy_prng_key('allow') + @jtu.ignore_warning(category=UserWarning, + message="Explicitly requested dtype float64 is not available") def test_jit_cache(self): if jtu.device_under_test() == "tpu": self.skipTest("64-bit random not available on TPU")