1. rbg is not identical across cpu/gpu/tpu;
2. the unsafe_rbg column copied the jax.lax.rng_uniform column from the original table, but that wasnt right, as it should be identical to the rbg column;
3. for the last row mentioning identical across shardings, we should mention that's assuming the xla flag
Also removed some rows which are only interesting in comparing to `jax.lax.rng_uniform` (which is not safe with `scan` or `remat`).
Co-authored-by: Roy Frostig <frostig@google.com>
Now with:
* resetting the `random.PRNGKeyArray` type during Python typechecks
* zeroing JVP rules for random primitives
* temporarily skipping vmap-of-pmap test with keys under `config.jax_array`
PiperOrigin-RevId: 469276609
Before this change, the Python PRNG key array was a pytree type
wrapping a `uint32` array. This was a stopgap that misbehaved under
`vmap`, `scan`, and even `jax.tree_map`. For a while, we thought we
might rely on something like the typeclass mechanisms in development
(e.g. `vmappable`) to move away from a pytree.
We're now taking a different approach: introducing key element types
into our IR and other internal machinery. During staging, we map
user-facing PRNG key arrays to abstract arrays such element type.
This leans heavily on our recently-introduced extended element type
capabilities.
As a consequence, `vmap`, `scan`, etc. now work.
A sample of changes made to introduce key-element-type arrays:
* Introduce a new element type (`prng.KeyTy`), with the requisite IR
type mapping and device result handlers, as well as lowering rules
for dtype-polymorphic primitive operations.
* Introduce primitives for basic RNG operations: `random_seed`,
`random_bits`, `random_split`, `random_fold_in`. These primitives
essentially delegate to the underlying PRNG implementation (directly
so in their impl rules, and by translating their staged-out form in
lowering rules).
* Also introduce `random_wrap` and `random_unwrap` for "unsafe"
conversion from/to the base `uint32` array. We need this backwards
compatibility, and it's useful for tests.
* Introduce some `vmap`-based helpers to adapt PRNG impls (which
define basic `random_bits`, `split`, etc. on scalars) to the above
batch-polymorphic primitives. Most of the primitives are vectorized,
but `random_fold_in` is a broadcasting binary op.
* Update the `gamma` primitive rules to account for key-element-type
abstract arrays (nice simplification here).
* Give PRNG implementation short string names ("tags") for IR
pretty-printing.
* Update `lax.stop_gradient` to handle opaque dtypes.
* Fix up loop MLIR lowering, which assumed that shaped arrays of all
dtypes have the same physical shape.
* Add new tests (exercising staging, jaxprs, lowerings, ...)
A sample of changes made to rework Python-level PRNG key arrays:
* Mimic `isinstance(x, KeyArray)` checks on abstract key arrays and
tracers that carry them.
* Patch (only a subset of) standard device array attributes onto PRNG
key arrays.
* Implement various conversion handlers (sharding, constant-creation,
`device_put`).
* Accept PRNG key arrays as input to `lax_numpy.transpose`.
* Update tests and rename some internals.
A sample of extra changes along the way:
* Disallow AD on key-typed arrays in the main API.
* Hoist `random_bits`'s named-shape-handling logic, which used to only
take place in the threefry PRNG's `random_bits` implementation, up
to the new `random_bits` traceable, so that we apply it consistently
across PRNG implementations.
This change leaves some unwanted `lax` and `jax.numpy` operations
superficially available on key arrays during tracing/staging
(e.g. under `jit`), though not outside of it. We ultimately want to
disallow these and raise useful errors, and I'm leaving that for
follow-up work. For now, applying such operations under `jit` may
result in downstream errors in the middle-end instead.
Everything here is still guarded by `config.jax_enable_custom_prng`,
whose default setting hasn't changed (it is off).
A PRNG implementation is determined by a key shape and a set of basic
functions on such a key: seed, split, random_bits, and fold_in.
A PRNG implementation can then by lifted to an array-of-keys-like
object. Namely, a new internal pytree class PRNGKeyArray wraps the
implementation and maintains an array of keys of the right shape. This
array-like object is the new "key" that gets passed around the various
functions in the public random API (e.g. `random.uniform`,
`random.normal`, ...). So the PRNGKeyArray class really serves two
purposes at once:
1. To adapt key implementations into "arrays" of such keys.
2. To carry a reference to the PRNG implementation around and delegate
back to it from the functions in random.
1. Build on Windows
2. Fix OverflowError
When calling `key = random.PRNGKey(0)` OverflowError: Python int too
large to convert to C long for casting value 4294967295 (0xFFFFFFFF)
from python int to int32.
3. fix file path in regex of errors_test
4. handle ValueError of os.path.commonpath
A user observed -inf values being returned by truncated_normal(), which occur if the uniform random value passed to erfinv() is out of range, e.g., due to rounding. Do more of the computation using jax.random.uniform(), which promises correct behavior in the face of rounding.
As an added security measure, also clamp the outputs of the function to the open interval.
Since we do the threefry with signed integers when converting to TF,
we run into the type promotion 'uint32 - int32 = int64', which
then results in lax.shift_right_logical(uint32, int64), which fails.
* [jax2tf] implementation of random_gamma
The simplest implementation is by converting the JAX own impl_rule,
which rewrites gamma into other JAX primitives.
On TPU with use_vmap=True the performance is the same for JAX and TF, provided
we use tf.function(compile=True).
* allow random.choice to accept ndarray `a`
follow-up to #4137 to allow ndarray inputs to be passed
* add jax.random.choice tests to cover ndarray input
* don't use callables in test params
it can mess with pytest-xdist because of hashing by id