The latter function is also better in that its behavior is invariant to `jit`,
whereas the `unsafe_raw_array` method only works in eager mode.
PiperOrigin-RevId: 565195381
Before this change, JAX could dispatch compiled functions over new-style (typed)
RNG key arrays, but it would always do so off of the fast (C++-based) dispatch
path. In other words, switching from old-style `uint32` RNG keys to new-style
keys would regress dispatch times. With this change, dispatch happens on the
fast path again and performance regressions ought to be minimal.
We currently maintain only one pytree registry, for all registered pytree node
types. We want RNG key arrays to also be treated as pytree leaves everywhere
*except* during dispatch. In other words: we want operations on (typed) RNG key
arrays to appear in Jaxpr, but we want to unravel those arrays into their
underlying `uint32` arrays only during dispatch.
To do this, we add a new internal pytree registry that dispatch respects
uniquely. This registry includes all items in the default registry, but also the
RNG key array type.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 565077758
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
Several PRNG implementations (notably partitionable threefry) support
splitting to arbitrary shapes, rather than only to a 1-D vector of
keys. This change:
* Upgrades `jax.random.split` to accept a general shape as an
argument.
* Updates the internal PRNG interface, and our various PRNG
implementations, to accept and handle such a shape argument.
This change keeps the argument name `num`. We can still think on
whether and how we'd like to upgrade to `shape`.
Note that we could have supported arbitrary shapes by reduction to the
previous API (with a flat split count), using reshapes. We'd like to
avoid that, so as not to hide this structure from the underlying
implementation. For instance, partitionable threefry hashes a *shaped*
iota in order to split keys, and we don't want to flatten and reshape
around that for no reason.
Co-authored-by: Jake Vanderplas <jakevdp@google.com>
This is an incremental change to our random tests that primarily:
* Increases test coverage of both key constructors (`random.key` and
`random.PRNGKey`), often by parameterizing tests over both.
* Increases test coverage of both key representations (typed key
arrays and `uint32` arrays).
* Removes a handful of guards on `config.jax_enable_custom_prng`,
either replacing them with `isinstance` checks for typed keys or
removing them altogether if possible.
* Makes a handful of other individual test improvements and fixes, and
leaves comments for more.
This change primarily adds an optional argument to both old- and
new-style random key constructors. The option determines the PRNG
implementation for the key by name, overriding any default
implementation determined by configuration flags.
Along the way, looking ahead:
* We can deprecate the (anyway underused) individual explicit key
constructors like `jax.random.threefr2x32_key` in favor of this
option.
* Some day, instead of only accepting RNG implementations by name
(string), we can also accept the output of some custom PRNG
implementation API that we expose, maybe via `jax.extend.random`
(corresponding roughly to the current `_src.prng.PRNGImpl`).
We're now moving to a world where custom PRNG should exist side-by-side with the old PRNG
implementation. This change improves test coverage for that, by enabling relevant tests
even when the flag is set to False.
Previously, this function's implementation would implicitly map over non-scalar
seed inputs. This is not the behavior we want, because in the future we may want
to allow arrays of integers as a single seed.