Merge pull request #14658 from JiaYaobo:chisq_and_f_dist

PiperOrigin-RevId: 513220241
This commit is contained in:
jax authors 2023-03-01 06:35:34 -08:00
commit fa1ea37704
4 changed files with 129 additions and 0 deletions

View File

@ -21,10 +21,12 @@ List of Available Functions
beta
categorical
cauchy
chisquare
choice
dirichlet
double_sided_maxwell
exponential
f
fold_in
gamma
generalized_normal

View File

@ -1497,6 +1497,102 @@ def _t(key, df, shape, dtype) -> Array:
return n * jnp.sqrt(half_df / g)
def chisquare(key: KeyArray,
df: RealArray,
shape: Optional[Shape] = None,
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
"""Sample Chisquare random values with given shape and float dtype.
Args:
key: a PRNG key used as the random key.
df: a float or array of floats broadcast-compatible with ``shape``
representing the parameter of the distribution.
shape: optional, a tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with ``df``. The default (None)
produces a result shape equal to ``df.shape``.
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
Returns:
A random array with the specified dtype and with shape given by ``shape`` if
``shape`` is not None, or else by ``df.shape``.
"""
key, _ = _check_prng_key(key)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError("dtype argument to `chisquare` must be a float "
f"dtype, got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
if shape is not None:
shape = core.canonicalize_shape(shape)
return _chisquare(key, df, shape, dtype)
@partial(jit, static_argnums=(2, 3), inline=True)
def _chisquare(key, df, shape, dtype) -> Array:
if shape is None:
shape = np.shape(df)
else:
_check_shape("chisquare", shape, np.shape(df))
df = lax.convert_element_type(df, dtype)
two = _lax_const(df, 2)
half_df = lax.div(df, two)
log_g = loggamma(key, a=half_df, shape=shape, dtype=dtype)
chi2 = lax.mul(jnp.exp(log_g), two)
return chi2
def f(key: KeyArray,
dfnum: RealArray,
dfden: RealArray,
shape: Optional[Shape] = None,
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
"""Sample F-distribution random values with given shape and float dtype.
Args:
key: a PRNG key used as the random key.
dfnum: a float or array of floats broadcast-compatible with ``shape``
representing the numerator's ``df`` of the distribution.
dfden: a float or array of floats broadcast-compatible with ``shape``
representing the denominator's ``df`` of the distribution.
shape: optional, a tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with ``dfnum`` and ``dfden``.
The default (None) produces a result shape equal to ``dfnum.shape``,
and ``dfden.shape``.
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
Returns:
A random array with the specified dtype and with shape given by ``shape`` if
``shape`` is not None, or else by ``df.shape``.
"""
key, _ = _check_prng_key(key)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError("dtype argument to `f` must be a float "
f"dtype, got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
if shape is not None:
shape = core.canonicalize_shape(shape)
return _f(key, dfnum, dfden, shape, dtype)
@partial(jit, static_argnums=(3, 4), inline=True)
def _f(key, dfnum, dfden, shape, dtype) -> Array:
if shape is None:
shape = lax.broadcast_shapes(np.shape(dfden), np.shape(dfnum))
else:
_check_shape("f", shape, np.shape(dfden), np.shape(dfnum))
dfden = lax.convert_element_type(dfden, dtype)
dfnum = lax.convert_element_type(dfnum, dtype)
key_dfd, key_dfn = _split(key)
chi2_dfn = chisquare(key_dfn, dfnum, shape, dtype)
chi2_dfd = chisquare(key_dfd, dfden, shape, dtype)
# broadcast dfden and dfnum to do div operation
dfden = jnp.broadcast_to(dfden, shape)
dfnum = jnp.broadcast_to(dfnum, shape)
num = lax.div(chi2_dfn, dfnum)
den = lax.div(chi2_dfd ,dfden)
f = lax.div(num, den)
return f
def rademacher(key: KeyArray,
shape: Shape,
dtype: DTypeLikeInt = dtypes.int_) -> Array:

View File

@ -153,11 +153,13 @@ from jax._src.random import (
beta as beta,
categorical as categorical,
cauchy as cauchy,
chisquare as chisquare,
choice as choice,
default_prng_impl as default_prng_impl,
dirichlet as dirichlet,
double_sided_maxwell as double_sided_maxwell,
exponential as exponential,
f as f,
fold_in as fold_in,
gamma as gamma,
generalized_normal as generalized_normal,

View File

@ -1489,6 +1489,35 @@ class LaxRandomTest(jtu.JaxTestCase):
axis=axis, shape=shape)
self.assertEqual(samples.shape, shape)
@jtu.sample_product(
df = [0.2, 1., 10., 100.],
dtype=jtu.dtypes.floating)
def testChisquare(self, df, dtype):
key = self.seed_prng(0)
rand = lambda key, df: random.chisquare(key, df, shape=(10000, ), dtype=dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key, df)
compiled_samples = crand(key, df)
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.chi2(df).cdf)
@jtu.sample_product(
dfnum = [1., 2., 10. ,100.],
dfden = [1. ,2., 10., 100.],
dtype=jtu.dtypes.floating)
def testF(self, dfnum, dfden, dtype):
key = self.seed_prng(1)
rand = lambda key: random.f(key, dfnum, dfden, shape = (10000, ), dtype = dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key)
compiled_samples = crand(key)
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.f(dfnum, dfden).cdf)
class KeyArrayTest(jtu.JaxTestCase):
# Key arrays involve: