A PRNG implementation is determined by a key shape and a set of basic
functions on such a key: seed, split, random_bits, and fold_in.
A PRNG implementation can then by lifted to an array-of-keys-like
object. Namely, a new internal pytree class PRNGKeyArray wraps the
implementation and maintains an array of keys of the right shape. This
array-like object is the new "key" that gets passed around the various
functions in the public random API (e.g. `random.uniform`,
`random.normal`, ...). So the PRNGKeyArray class really serves two
purposes at once:
1. To adapt key implementations into "arrays" of such keys.
2. To carry a reference to the PRNG implementation around and delegate
back to it from the functions in random.
1. Build on Windows
2. Fix OverflowError
When calling `key = random.PRNGKey(0)` OverflowError: Python int too
large to convert to C long for casting value 4294967295 (0xFFFFFFFF)
from python int to int32.
3. fix file path in regex of errors_test
4. handle ValueError of os.path.commonpath
A user observed -inf values being returned by truncated_normal(), which occur if the uniform random value passed to erfinv() is out of range, e.g., due to rounding. Do more of the computation using jax.random.uniform(), which promises correct behavior in the face of rounding.
As an added security measure, also clamp the outputs of the function to the open interval.
Since we do the threefry with signed integers when converting to TF,
we run into the type promotion 'uint32 - int32 = int64', which
then results in lax.shift_right_logical(uint32, int64), which fails.
* [jax2tf] implementation of random_gamma
The simplest implementation is by converting the JAX own impl_rule,
which rewrites gamma into other JAX primitives.
On TPU with use_vmap=True the performance is the same for JAX and TF, provided
we use tf.function(compile=True).
* allow random.choice to accept ndarray `a`
follow-up to #4137 to allow ndarray inputs to be passed
* add jax.random.choice tests to cover ndarray input
* don't use callables in test params
it can mess with pytest-xdist because of hashing by id
* Fix bug #3997, change `jax.random.multivariate_normal` to handle batches of covariance matrices. It works as long as mean and covariance are broadcast-compatible, as specified in the docstring.
* Fix bug in multivariate_normal shape checking
Minor bug: should be checking for compatibility of `shape`, `mean`, and the the last two dimensions of the _covariance_ matrix.
* Add test for multivariate_normal shapes
This test checks that `jax.random.multivariate_normal` produces the expected output shape for various combinations of event dimension and `mean`, `covariance`, and `shape` shapes.
* Fix linter issues in tests/random_test.py
Trimming trialing whitespace and 80 char limit.
* Really trimming whitespace in tests/random_test.py
Arg. Have to fix my editor to do this automatically.
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.
See https://github.com/google/jax/pull/3370 fo more information.