356 Commits

Author SHA1 Message Date
Jake VanderPlas
df6d84ecf4 remove tests for deprecated jax.random named key constructors
PiperOrigin-RevId: 568314098
2023-09-25 13:46:47 -07:00
Jake VanderPlas
70e0098a87 [random] add itemsize property to custom PRNG 2023-09-25 08:52:26 -07:00
Jake VanderPlas
8125e8bd03 issubdtype: fix corner cases with extended dtypes 2023-09-22 11:37:31 -07:00
Jake VanderPlas
243a6a236c dtypes.issubdtype: validate a when b is dtypes.extended 2023-09-21 15:53:05 -07:00
Jake VanderPlas
22818d664f [random] deprecate named key creation functions 2023-09-21 13:57:49 -07:00
Roy Frostig
2bf9322ccc move wrap_key_data to jax.random
This is a fine function for the public API, rather than `jax.extend`.
2023-09-18 14:38:22 -07:00
Roy Frostig
1f8cc44f4e deprecate PRNGKeyArray.unsafe_raw_array in favor of jax.random.key_data
The latter function is also better in that its behavior is invariant to `jit`,
whereas the `unsafe_raw_array` method only works in eager mode.

PiperOrigin-RevId: 565195381
2023-09-13 16:33:56 -07:00
Jake VanderPlas
270cc6014c Update internal callers to avoid PRNGKeyArray 2023-09-13 14:05:42 -07:00
Roy Frostig
6abefa1977 fast dispatch for functions over typed PRNG key arrays
Before this change, JAX could dispatch compiled functions over new-style (typed)
RNG key arrays, but it would always do so off of the fast (C++-based) dispatch
path. In other words, switching from old-style `uint32` RNG keys to new-style
keys would regress dispatch times. With this change, dispatch happens on the
fast path again and performance regressions ought to be minimal.

We currently maintain only one pytree registry, for all registered pytree node
types. We want RNG key arrays to also be treated as pytree leaves everywhere
*except* during dispatch. In other words: we want operations on (typed) RNG key
arrays to appear in Jaxpr, but we want to unravel those arrays into their
underlying `uint32` arrays only during dispatch.

To do this, we add a new internal pytree registry that dispatch respects
uniquely. This registry includes all items in the default registry, but also the
RNG key array type.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 565077758
2023-09-13 09:43:58 -07:00
Jake VanderPlas
f0309b49c9 jax.random: warn on unsupported dtypes 2023-08-31 10:56:05 -07:00
Jake VanderPlas
4b89d03147 Deprecate the contents of jax.prng 2023-08-30 15:13:32 -07:00
Jake VanderPlas
2f878a7168 Tests: set jax_legacy_prng_key='error' 2023-08-28 10:56:09 -07:00
Jake VanderPlas
630a69f41b [random] add jax_legacy_prng_key flag 2023-08-22 15:08:51 -07:00
jiayaobo
6d184459ef add random.triangular and random.lognormal
add random.triangular and random.lognormal

add random.triangular and random.lognormal
2023-08-05 12:27:53 +08:00
Jake VanderPlas
7d7a536b55 custom prng: introduce mechanism to identify key arrays by dtype 2023-07-21 12:27:32 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
jax authors
c006e52f1a Merge pull request #16779 from jakevdp:random-gamma
PiperOrigin-RevId: 549460066
2023-07-19 16:39:05 -07:00
Jake VanderPlas
7205160095 Re-parameterize jax.random.gamma for better behavior at endpoints 2023-07-19 16:15:03 -07:00
Lukas Geiger
de2c8541be Remove obsolete numpy version checks 2023-07-19 23:33:47 +01:00
Roy Frostig
9aa5307e2f API compatibility policy: expand on numerics and randomness 2023-07-18 14:13:25 -07:00
Roy Frostig
df2891ff13 accept general shape option in jax.random.split
Several PRNG implementations (notably partitionable threefry) support
splitting to arbitrary shapes, rather than only to a 1-D vector of
keys. This change:

* Upgrades `jax.random.split` to accept a general shape as an
  argument.
* Updates the internal PRNG interface, and our various PRNG
  implementations, to accept and handle such a shape argument.

This change keeps the argument name `num`. We can still think on
whether and how we'd like to upgrade to `shape`.

Note that we could have supported arbitrary shapes by reduction to the
previous API (with a flat split count), using reshapes. We'd like to
avoid that, so as not to hide this structure from the underlying
implementation. For instance, partitionable threefry hashes a *shaped*
iota in order to split keys, and we don't want to flatten and reshape
around that for no reason.

Co-authored-by: Jake Vanderplas <jakevdp@google.com>
2023-07-11 14:24:01 -07:00
Roy Frostig
ce9c2d650a rename seed_prng test method to make_key 2023-07-05 15:26:30 -07:00
Roy Frostig
ff70255af9 consistently seed keys indirectly by test class method in LaxRandomTest 2023-07-05 15:18:54 -07:00
Roy Frostig
556c1123cf parameterize two random tests over key constructors 2023-07-05 15:18:54 -07:00
Roy Frostig
c710c7578d move and remove code in random_test 2023-07-05 15:18:54 -07:00
jax authors
7c7051a4cc Merge pull request #16607 from froystig:random-test-double-threefry
PiperOrigin-RevId: 545799083
2023-07-05 15:15:35 -07:00
Roy Frostig
30542bd5bd match behavior of double-threefry test RNG and standard threefry RNG
This also lets us avoid a guard on `config.jax_enable_custom_prng` in
random tests.
2023-07-05 15:01:12 -07:00
Roy Frostig
09af6b1e01 test non-threefry RNGs across both typed and raw key formats
This also lets us remove some test guards on `config.jax_enable_custom_prng`.
2023-07-05 13:54:14 -07:00
Roy Frostig
bc44b99d05 avoid raw key arrays in typed key sharding test
This also lets us remove a guard on `config.jax_enable_custom_prng` in
random tests.
2023-06-30 20:38:26 -07:00
Roy Frostig
f8dee51d9a increase random test coverage over RNG key constructors and representations
This is an incremental change to our random tests that primarily:

* Increases test coverage of both key constructors (`random.key` and
  `random.PRNGKey`), often by parameterizing tests over both.

* Increases test coverage of both key representations (typed key
  arrays and `uint32` arrays).

* Removes a handful of guards on `config.jax_enable_custom_prng`,
  either replacing them with `isinstance` checks for typed keys or
  removing them altogether if possible.

* Makes a handful of other individual test improvements and fixes, and
  leaves comments for more.
2023-06-30 20:26:29 -07:00
Roy Frostig
9b346861a9 add impl option to random key constructors that picks the RNG implementation
This change primarily adds an optional argument to both old- and
new-style random key constructors. The option determines the PRNG
implementation for the key by name, overriding any default
implementation determined by configuration flags.

Along the way, looking ahead:

* We can deprecate the (anyway underused) individual explicit key
  constructors like `jax.random.threefr2x32_key` in favor of this
  option.

* Some day, instead of only accepting RNG implementations by name
  (string), we can also accept the output of some custom PRNG
  implementation API that we expose, maybe via `jax.extend.random`
  (corresponding roughly to the current `_src.prng.PRNGImpl`).
2023-06-30 16:42:22 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Jake VanderPlas
39645b5c20 Custom PRNG: improve test coverage when enable_custom_prng=false
We're now moving to a world where custom PRNG should exist side-by-side with the old PRNG
implementation. This change improves test coverage for that, by enabling relevant tests
even when the flag is set to False.
2023-06-23 00:29:11 -07:00
Jake VanderPlas
951d515701 random.key: error for non-scalar seeds.
Previously, this function's implementation would implicitly map over non-scalar
seed inputs. This is not the behavior we want, because in the future we may want
to allow arrays of integers as a single seed.
2023-06-20 01:16:25 -07:00
Jake VanderPlas
47ce94fa0d Custom PRNG: fix incorrect assertion in select lowering rule 2023-06-15 02:43:43 -07:00
Jake VanderPlas
3a7ccf70f2 custom prng: add shard methods to PRNGKeyArrayImpl 2023-06-01 04:10:12 -07:00
Jake VanderPlas
48abe7c684 PRNGKeyArray: add several missing attributes & methods 2023-05-17 14:47:22 -07:00
Jake VanderPlas
6ef4e5f01a Custom PRNG: make KeyArray compatible with custom_jvp 2023-05-17 10:31:09 -07:00
Jake VanderPlas
0e483223c6 Custom PRNG: support lax.full() and related constructors 2023-05-17 09:04:50 -07:00
Jake VanderPlas
1b00abf819 custom PRNG: better error when using key as seed 2023-05-16 13:47:19 -07:00
Jake VanderPlas
5e14744c2c jax.random.choice: make return dtype consistent 2023-05-16 08:52:11 -07:00
Jake VanderPlas
b9aa236dac Custom PRNG: support PRNGKeyArray.copy() 2023-05-12 15:50:22 -07:00
Jake VanderPlas
b250c706b0 Allow opaque dtypes in grad with allow_int=True 2023-05-10 11:43:17 -07:00
Jake VanderPlas
6ada8785aa PRNGKeyArray: fix dynamic slice index dtype 2023-05-10 09:24:18 -07:00
Roy Frostig
051c5dda6e delegate select lowering to opaque dtype rule
... and implement it for PRNG key arrays
2023-05-08 19:02:42 -07:00
Jake VanderPlas
4db717c52a KeyArray: support make_array_from_* APIs 2023-05-04 16:32:49 -07:00
jax authors
5d143e6eea Merge pull request #15818 from froystig:random-bits-direct
PiperOrigin-RevId: 529090390
2023-05-03 07:56:17 -07:00
Roy Frostig
ea3389205f add jax.random.bits 2023-05-03 06:10:05 -07:00
Jake VanderPlas
979aa3235b KeyArray: implement sharded & replicated device_put 2023-05-01 14:17:01 -07:00
Jake VanderPlas
054fca5cd4 KeyArray: define itemsize on opaque dtype 2023-04-27 15:59:57 -07:00