mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Merge pull request #14658 from JiaYaobo:chisq_and_f_dist
PiperOrigin-RevId: 513220241
This commit is contained in:
commit
fa1ea37704
@ -21,10 +21,12 @@ List of Available Functions
|
||||
beta
|
||||
categorical
|
||||
cauchy
|
||||
chisquare
|
||||
choice
|
||||
dirichlet
|
||||
double_sided_maxwell
|
||||
exponential
|
||||
f
|
||||
fold_in
|
||||
gamma
|
||||
generalized_normal
|
||||
|
@ -1497,6 +1497,102 @@ def _t(key, df, shape, dtype) -> Array:
|
||||
return n * jnp.sqrt(half_df / g)
|
||||
|
||||
|
||||
def chisquare(key: KeyArray,
|
||||
df: RealArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
"""Sample Chisquare random values with given shape and float dtype.
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
df: a float or array of floats broadcast-compatible with ``shape``
|
||||
representing the parameter of the distribution.
|
||||
shape: optional, a tuple of nonnegative integers specifying the result
|
||||
shape. Must be broadcast-compatible with ``df``. The default (None)
|
||||
produces a result shape equal to ``df.shape``.
|
||||
dtype: optional, a float dtype for the returned values (default float64 if
|
||||
jax_enable_x64 is true, otherwise float32).
|
||||
|
||||
Returns:
|
||||
A random array with the specified dtype and with shape given by ``shape`` if
|
||||
``shape`` is not None, or else by ``df.shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError("dtype argument to `chisquare` must be a float "
|
||||
f"dtype, got {dtype}")
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
if shape is not None:
|
||||
shape = core.canonicalize_shape(shape)
|
||||
return _chisquare(key, df, shape, dtype)
|
||||
|
||||
@partial(jit, static_argnums=(2, 3), inline=True)
|
||||
def _chisquare(key, df, shape, dtype) -> Array:
|
||||
if shape is None:
|
||||
shape = np.shape(df)
|
||||
else:
|
||||
_check_shape("chisquare", shape, np.shape(df))
|
||||
df = lax.convert_element_type(df, dtype)
|
||||
two = _lax_const(df, 2)
|
||||
half_df = lax.div(df, two)
|
||||
log_g = loggamma(key, a=half_df, shape=shape, dtype=dtype)
|
||||
chi2 = lax.mul(jnp.exp(log_g), two)
|
||||
return chi2
|
||||
|
||||
|
||||
def f(key: KeyArray,
|
||||
dfnum: RealArray,
|
||||
dfden: RealArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
"""Sample F-distribution random values with given shape and float dtype.
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
dfnum: a float or array of floats broadcast-compatible with ``shape``
|
||||
representing the numerator's ``df`` of the distribution.
|
||||
dfden: a float or array of floats broadcast-compatible with ``shape``
|
||||
representing the denominator's ``df`` of the distribution.
|
||||
shape: optional, a tuple of nonnegative integers specifying the result
|
||||
shape. Must be broadcast-compatible with ``dfnum`` and ``dfden``.
|
||||
The default (None) produces a result shape equal to ``dfnum.shape``,
|
||||
and ``dfden.shape``.
|
||||
dtype: optional, a float dtype for the returned values (default float64 if
|
||||
jax_enable_x64 is true, otherwise float32).
|
||||
|
||||
Returns:
|
||||
A random array with the specified dtype and with shape given by ``shape`` if
|
||||
``shape`` is not None, or else by ``df.shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError("dtype argument to `f` must be a float "
|
||||
f"dtype, got {dtype}")
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
if shape is not None:
|
||||
shape = core.canonicalize_shape(shape)
|
||||
return _f(key, dfnum, dfden, shape, dtype)
|
||||
|
||||
@partial(jit, static_argnums=(3, 4), inline=True)
|
||||
def _f(key, dfnum, dfden, shape, dtype) -> Array:
|
||||
if shape is None:
|
||||
shape = lax.broadcast_shapes(np.shape(dfden), np.shape(dfnum))
|
||||
else:
|
||||
_check_shape("f", shape, np.shape(dfden), np.shape(dfnum))
|
||||
dfden = lax.convert_element_type(dfden, dtype)
|
||||
dfnum = lax.convert_element_type(dfnum, dtype)
|
||||
key_dfd, key_dfn = _split(key)
|
||||
chi2_dfn = chisquare(key_dfn, dfnum, shape, dtype)
|
||||
chi2_dfd = chisquare(key_dfd, dfden, shape, dtype)
|
||||
# broadcast dfden and dfnum to do div operation
|
||||
dfden = jnp.broadcast_to(dfden, shape)
|
||||
dfnum = jnp.broadcast_to(dfnum, shape)
|
||||
num = lax.div(chi2_dfn, dfnum)
|
||||
den = lax.div(chi2_dfd ,dfden)
|
||||
f = lax.div(num, den)
|
||||
return f
|
||||
|
||||
|
||||
def rademacher(key: KeyArray,
|
||||
shape: Shape,
|
||||
dtype: DTypeLikeInt = dtypes.int_) -> Array:
|
||||
|
@ -153,11 +153,13 @@ from jax._src.random import (
|
||||
beta as beta,
|
||||
categorical as categorical,
|
||||
cauchy as cauchy,
|
||||
chisquare as chisquare,
|
||||
choice as choice,
|
||||
default_prng_impl as default_prng_impl,
|
||||
dirichlet as dirichlet,
|
||||
double_sided_maxwell as double_sided_maxwell,
|
||||
exponential as exponential,
|
||||
f as f,
|
||||
fold_in as fold_in,
|
||||
gamma as gamma,
|
||||
generalized_normal as generalized_normal,
|
||||
|
@ -1489,6 +1489,35 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
axis=axis, shape=shape)
|
||||
self.assertEqual(samples.shape, shape)
|
||||
|
||||
@jtu.sample_product(
|
||||
df = [0.2, 1., 10., 100.],
|
||||
dtype=jtu.dtypes.floating)
|
||||
def testChisquare(self, df, dtype):
|
||||
key = self.seed_prng(0)
|
||||
|
||||
rand = lambda key, df: random.chisquare(key, df, shape=(10000, ), dtype=dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key, df)
|
||||
compiled_samples = crand(key, df)
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.chi2(df).cdf)
|
||||
|
||||
@jtu.sample_product(
|
||||
dfnum = [1., 2., 10. ,100.],
|
||||
dfden = [1. ,2., 10., 100.],
|
||||
dtype=jtu.dtypes.floating)
|
||||
def testF(self, dfnum, dfden, dtype):
|
||||
key = self.seed_prng(1)
|
||||
rand = lambda key: random.f(key, dfnum, dfden, shape = (10000, ), dtype = dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key)
|
||||
compiled_samples = crand(key)
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.f(dfnum, dfden).cdf)
|
||||
|
||||
class KeyArrayTest(jtu.JaxTestCase):
|
||||
# Key arrays involve:
|
||||
|
Loading…
x
Reference in New Issue
Block a user