209 Commits

Author SHA1 Message Date
Jake VanderPlas
5f7335fb55 Deprecate jax.random.shuffle
This has been long deprecated, but this PR uses the standard deprecation
framework to make it easier to finalize.
2023-11-06 12:21:56 -08:00
Jake VanderPlas
a4e6b4e943 [random] add more information to KeyArray deprecation error 2023-10-31 16:54:49 -07:00
Jake VanderPlas
53c4de477e [random] deprecate jax.random.default_prng_impl() 2023-10-19 13:59:01 -07:00
Jake VanderPlas
06306274e5 Fix type checking declaration of jax.random.threefry2x32_p
Followup to https://github.com/google/jax/pull/18176

PiperOrigin-RevId: 574891218
2023-10-19 09:05:50 -07:00
Jake VanderPlas
b865827d06 [random] deprecate jax.random.threefry_2x32 & threefry2x32_p 2023-10-18 14:42:49 -07:00
Roy Frostig
5158e251b6 identify PRNG schemes on key arrays, and recognize them in key constructors
Specifically:

* Introduce `jax.random.key_impl`, which accepts a key array and
  returns a hashable identifier of its PRNG implementation.

* Accept this identifier optionally as the `impl` argument to
  `jax.random.key` and `wrap_key_data`.

This now works:

```python
k1 = jax.random.key(72, impl='threefry2x32')
impl = jax.random.key_impl(k1)
k2 = jax.random.key(72, impl=impl)
assert arrays_equal(k1, k2)
assert k1.dtype == k2.dtype
```

This change also set up an internal PRNG registry and register
built-in implementations, to simplify various places where we
essentially reconstruct such a registry from scratch (such as in
tests).

Co-authored-by: Jake Vanderplas <jakevdp@google.com>
2023-10-06 10:15:08 -07:00
Jake VanderPlas
1d8f2ac08f fix typo
PiperOrigin-RevId: 567507063
2023-09-21 21:31:19 -07:00
Jake VanderPlas
243a6a236c dtypes.issubdtype: validate a when b is dtypes.extended 2023-09-21 15:53:05 -07:00
Jake VanderPlas
22818d664f [random] deprecate named key creation functions 2023-09-21 13:57:49 -07:00
Roy Frostig
2bf9322ccc move wrap_key_data to jax.random
This is a fine function for the public API, rather than `jax.extend`.
2023-09-18 14:38:22 -07:00
Jake VanderPlas
4e6c1b68c7 Deprecate random.KeyArray and random.PRNGKeyArray 2023-09-13 14:05:42 -07:00
jiayaobo
6d184459ef add random.triangular and random.lognormal
add random.triangular and random.lognormal

add random.triangular and random.lognormal
2023-08-05 12:27:53 +08:00
Roy Frostig
b7b90e62e3 add new random key constructor
This constructor unconditionally returns a typed key array, regardless
of the value of `jax.config.enable_custom_prng`. We can switch to
referring to it in randomness docs and tutorials as we complete the
typed key upgrade.
2023-05-22 11:35:10 -07:00
Roy Frostig
ea3389205f add jax.random.bits 2023-05-03 06:10:05 -07:00
Matthew Johnson
a7f5e07549 update prng docs to mention jax_threefry_partitionable
fixes #15484
2023-04-07 22:55:15 -07:00
jiayaobo
924894e85c add geometric random gen
add geom random

add geom random

add geom random

add geom random
2023-03-30 02:08:04 +08:00
jiayaobo
f7a14d65d2 add wald random generator
add wald to random.py
2023-03-22 11:06:59 +08:00
jiayaobo
05c47033b2 add rayleigh distribution to random.py
add rayleigh distribution to random.py

add rayleigh distribution to random.py

add rayleigh to random.py
2023-03-14 09:55:54 +08:00
jiayaobo
fdf8ac18d6 add random.chisquare and random.f
add chi2 and F random variables methods

add chi2 and F random variables methods

fix F rv shape broadcasting

fix shape broadcasting
2023-03-01 15:03:50 +08:00
Jake VanderPlas
26f2f97805 Document why 'import name as name' is used 2022-12-14 15:07:04 -08:00
Yash Katariya
a419e1917a Use jax.Array by default for doctests
PiperOrigin-RevId: 488719467
2022-11-15 11:52:22 -08:00
Peter Hawkins
cd84eb10a6 Add a number of missing function cross-references in the docs. 2022-11-07 12:00:26 -05:00
Matthew Johnson
478bd3ea4e fix comparison table in random docs
1. rbg is not identical across cpu/gpu/tpu;
2. the unsafe_rbg column copied the jax.lax.rng_uniform column from the original table, but that wasnt right, as it should be identical to the rbg column;
3. for the last row mentioning identical across shardings, we should mention that's assuming the xla flag

Also removed some rows which are only interesting in comparing to `jax.lax.rng_uniform` (which is not safe with `scan` or `remat`).

Co-authored-by: Roy Frostig <frostig@google.com>
2022-11-03 13:10:54 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Roy Frostig
0d3630b349 add key_data to jax.random for key array unwrapping
This is often useful in testing and debugging. Its more dangerous
inverse, wrapping, remains internal only.
2022-08-31 09:23:11 -07:00
Roy Frostig
6071a8f875 roll-forward #11952, take 2
Now with:
* resetting the `random.PRNGKeyArray` type during Python typechecks
* zeroing JVP rules for random primitives
* temporarily skipping vmap-of-pmap test with keys under `config.jax_array`

PiperOrigin-RevId: 469276609
2022-08-22 13:57:31 -07:00
jax authors
3a2f25ff31 roll-forward #11952
... with a small adjustment, resetting the `random.PRNGKeyArray` type
during Python typechecking.

PiperOrigin-RevId: 468840334
2022-08-19 21:02:18 -07:00
Roy Frostig
9789e83b26 roll-forward #11952
... with a small adjustment, resetting the `random.PRNGKeyArray` type
during Python typechecking.

PiperOrigin-RevId: 468835674
2022-08-19 20:12:32 -07:00
jax authors
a6c6416872 Internal change
PiperOrigin-RevId: 468712508
2022-08-19 08:56:49 -07:00
Roy Frostig
7f06df1ea1 introduce key-element-type arrays and overhaul the Python PRNG key array type
Before this change, the Python PRNG key array was a pytree type
wrapping a `uint32` array. This was a stopgap that misbehaved under
`vmap`, `scan`, and even `jax.tree_map`. For a while, we thought we
might rely on something like the typeclass mechanisms in development
(e.g. `vmappable`) to move away from a pytree.

We're now taking a different approach: introducing key element types
into our IR and other internal machinery. During staging, we map
user-facing PRNG key arrays to abstract arrays such element type.

This leans heavily on our recently-introduced extended element type
capabilities.

As a consequence, `vmap`, `scan`, etc. now work.

A sample of changes made to introduce key-element-type arrays:

* Introduce a new element type (`prng.KeyTy`), with the requisite IR
  type mapping and device result handlers, as well as lowering rules
  for dtype-polymorphic primitive operations.

* Introduce primitives for basic RNG operations: `random_seed`,
  `random_bits`, `random_split`, `random_fold_in`. These primitives
  essentially delegate to the underlying PRNG implementation (directly
  so in their impl rules, and by translating their staged-out form in
  lowering rules).

* Also introduce `random_wrap` and `random_unwrap` for "unsafe"
  conversion from/to the base `uint32` array. We need this backwards
  compatibility, and it's useful for tests.

* Introduce some `vmap`-based helpers to adapt PRNG impls (which
  define basic `random_bits`, `split`, etc. on scalars) to the above
  batch-polymorphic primitives. Most of the primitives are vectorized,
  but `random_fold_in` is a broadcasting binary op.

* Update the `gamma` primitive rules to account for key-element-type
  abstract arrays (nice simplification here).

* Give PRNG implementation short string names ("tags") for IR
  pretty-printing.

* Update `lax.stop_gradient` to handle opaque dtypes.

* Fix up loop MLIR lowering, which assumed that shaped arrays of all
  dtypes have the same physical shape.

* Add new tests (exercising staging, jaxprs, lowerings, ...)

A sample of changes made to rework Python-level PRNG key arrays:

* Mimic `isinstance(x, KeyArray)` checks on abstract key arrays and
  tracers that carry them.

* Patch (only a subset of) standard device array attributes onto PRNG
  key arrays.

* Implement various conversion handlers (sharding, constant-creation,
  `device_put`).

* Accept PRNG key arrays as input to `lax_numpy.transpose`.

* Update tests and rename some internals.

A sample of extra changes along the way:

* Disallow AD on key-typed arrays in the main API.

* Hoist `random_bits`'s named-shape-handling logic, which used to only
  take place in the threefry PRNG's `random_bits` implementation, up
  to the new `random_bits` traceable, so that we apply it consistently
  across PRNG implementations.

This change leaves some unwanted `lax` and `jax.numpy` operations
superficially available on key arrays during tracing/staging
(e.g. under `jit`), though not outside of it. We ultimately want to
disallow these and raise useful errors, and I'm leaving that for
follow-up work. For now, applying such operations under `jit` may
result in downstream errors in the middle-end instead.

Everything here is still guarded by `config.jax_enable_custom_prng`,
whose default setting hasn't changed (it is off).
2022-08-18 21:46:55 -07:00
Peter Hawkins
71b29b1cc6 Create JAX Enhancement Proposals (JEPs).
Migrate existing design documents to JEPs.
2022-08-08 16:13:58 -04:00
James Bradbury
64eb46a172
Fix RST formatting in random.py docstring 2022-07-11 02:51:35 -07:00
Tamara Norman
bc9c4b77d0 Adjust docs to account for what the actual current RNG behavior is
PiperOrigin-RevId: 459712928
2022-07-08 02:55:36 -07:00
carlosgmartin
ca83a80f95 Added random.generalized_normal and random.ball. 2022-06-03 15:11:29 -04:00
Carlos Martin
b276c31b75 Added random.orthogonal. 2022-04-29 14:20:50 -04:00
Jake VanderPlas
5782210174 CI: fix flake8 ignore declarations 2022-04-21 13:44:12 -07:00
Jeppe Klitgaard
342923335a fix: explicit reexport
Explicitly reexports PRNGKeyArray as KeyArray in accordance with PEP 484

See also: https://github.com/python/mypy/issues/11706
2022-04-19 17:54:32 +01:00
Jean-Baptiste
46a666c448 Improve the random module documentation. 2022-04-15 10:49:52 +02:00
Jake VanderPlas
69969ef803 add random.loggamma and improve dirichlet & beta implementation 2022-03-21 08:33:11 -07:00
Roy Frostig
026b91b85d add random.default_prng_impl to retrieve the default PRNG implementation 2022-01-12 19:13:14 -08:00
Roy Frostig
98d245ebb4 add a config setting to control the default PRNG implementation
Also add explicit seeding functions for each PRNG implementation.
2021-10-07 21:22:40 -07:00
jax authors
47b40e001d Merge pull request #7847 from google:prng-type
PiperOrigin-RevId: 396661434
2021-09-14 12:55:28 -07:00
Jake VanderPlas
245581411e Add PEP484-compatible export for jax and its subpackages 2021-09-13 14:08:48 -07:00
Roy Frostig
9134794ddf expose random key array type
This is meant to allow for downstream type annotations and checking
(e.g. via `isinstance`).
2021-09-07 16:30:13 -07:00
Roy Frostig
4eb437a568 alias prng.threefry_2x32 in random and warn of move
Some call this, apparently.
2021-08-19 20:43:11 -07:00
Roy Frostig
aa265cce95 introduce custom PRNG implementations and an array-like adapter for them
A PRNG implementation is determined by a key shape and a set of basic
functions on such a key: seed, split, random_bits, and fold_in.

A PRNG implementation can then by lifted to an array-of-keys-like
object. Namely, a new internal pytree class PRNGKeyArray wraps the
implementation and maintains an array of keys of the right shape. This
array-like object is the new "key" that gets passed around the various
functions in the public random API (e.g. `random.uniform`,
`random.normal`, ...). So the PRNGKeyArray class really serves two
purposes at once:

1. To adapt key implementations into "arrays" of such keys.
2. To carry a reference to the PRNG implementation around and delegate
   back to it from the functions in random.
2021-08-19 20:43:11 -07:00
Roy Frostig
b4ccecca88 factor PRNG routines from random module to prng 2021-08-17 19:27:31 -07:00
George Necula
6a48c60a72 Rename master to main in embedded links.
Tried to avoid the change on external links to repos that
have not yet renamed master.
2021-06-18 10:00:01 +03:00
Jake VanderPlas
8e789c7380 Run doctest on all source files except jax2tf 2021-04-05 10:39:59 -07:00
Jake VanderPlas
4ae393f8a8 DOC: fix jax.random documentation 2020-11-20 16:41:32 -08:00