221 Commits

Author SHA1 Message Date
Jake VanderPlas
a17c8d945b Finalize deprecation of jax.random.shuffle
This has been raising a DeprecationWarning for longer than anyone can remember.

PiperOrigin-RevId: 656765001
2024-07-27 11:21:49 -07:00
Roy Frostig
f30ebd8586 document vmap peculiarity of experimental RNG implementations 2024-07-26 13:40:16 -07:00
Roy Frostig
6ddd488df0 improve RNG doc around implementation configuration 2024-07-26 13:40:16 -07:00
jax authors
0302e4c34d Merge pull request #17741 from froystig:new-style-key-docs
PiperOrigin-RevId: 614080080
2024-03-08 16:41:22 -08:00
jax authors
0f2a89d837 Merge pull request #20139 from jakevdp:random-clone
PiperOrigin-RevId: 613979093
2024-03-08 10:42:16 -08:00
Jake VanderPlas
6771a59181 [key reuse] add jax.random.clone 2024-03-08 09:06:00 -08:00
yixiaoer
6ada248b3c update 2024-03-09 00:38:28 +08:00
Roy Frostig
721ca3f714 add key array upgrade note to jax.random module doc 2024-03-07 12:56:42 -08:00
Roy Frostig
98f790f5d5 update package/API reference docs to new-style typed PRNG keys 2024-03-07 12:40:09 -08:00
Jake VanderPlas
9b9aa1efaf Finalize a number of deprecations from JAX 0.4.19
PiperOrigin-RevId: 600509530
2024-01-22 11:13:25 -08:00
Jake VanderPlas
e356d76913 Remove a number of deprecated APIs
All of these were deprecated prior to the JAX 0.4.16 release, on Sept 18 2023.
As of Monday Dec 18, we have met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).

PiperOrigin-RevId: 591933493
2023-12-18 10:08:47 -08:00
jiayaobo
ae2387dc27 add random.binomial
update

update

modify
2023-11-19 14:51:10 +08:00
Jake VanderPlas
5f7335fb55 Deprecate jax.random.shuffle
This has been long deprecated, but this PR uses the standard deprecation
framework to make it easier to finalize.
2023-11-06 12:21:56 -08:00
Jake VanderPlas
a4e6b4e943 [random] add more information to KeyArray deprecation error 2023-10-31 16:54:49 -07:00
Jake VanderPlas
53c4de477e [random] deprecate jax.random.default_prng_impl() 2023-10-19 13:59:01 -07:00
Jake VanderPlas
06306274e5 Fix type checking declaration of jax.random.threefry2x32_p
Followup to https://github.com/google/jax/pull/18176

PiperOrigin-RevId: 574891218
2023-10-19 09:05:50 -07:00
Jake VanderPlas
b865827d06 [random] deprecate jax.random.threefry_2x32 & threefry2x32_p 2023-10-18 14:42:49 -07:00
Roy Frostig
5158e251b6 identify PRNG schemes on key arrays, and recognize them in key constructors
Specifically:

* Introduce `jax.random.key_impl`, which accepts a key array and
  returns a hashable identifier of its PRNG implementation.

* Accept this identifier optionally as the `impl` argument to
  `jax.random.key` and `wrap_key_data`.

This now works:

```python
k1 = jax.random.key(72, impl='threefry2x32')
impl = jax.random.key_impl(k1)
k2 = jax.random.key(72, impl=impl)
assert arrays_equal(k1, k2)
assert k1.dtype == k2.dtype
```

This change also set up an internal PRNG registry and register
built-in implementations, to simplify various places where we
essentially reconstruct such a registry from scratch (such as in
tests).

Co-authored-by: Jake Vanderplas <jakevdp@google.com>
2023-10-06 10:15:08 -07:00
Jake VanderPlas
1d8f2ac08f fix typo
PiperOrigin-RevId: 567507063
2023-09-21 21:31:19 -07:00
Jake VanderPlas
243a6a236c dtypes.issubdtype: validate a when b is dtypes.extended 2023-09-21 15:53:05 -07:00
Jake VanderPlas
22818d664f [random] deprecate named key creation functions 2023-09-21 13:57:49 -07:00
Roy Frostig
2bf9322ccc move wrap_key_data to jax.random
This is a fine function for the public API, rather than `jax.extend`.
2023-09-18 14:38:22 -07:00
Jake VanderPlas
4e6c1b68c7 Deprecate random.KeyArray and random.PRNGKeyArray 2023-09-13 14:05:42 -07:00
jiayaobo
6d184459ef add random.triangular and random.lognormal
add random.triangular and random.lognormal

add random.triangular and random.lognormal
2023-08-05 12:27:53 +08:00
Roy Frostig
b7b90e62e3 add new random key constructor
This constructor unconditionally returns a typed key array, regardless
of the value of `jax.config.enable_custom_prng`. We can switch to
referring to it in randomness docs and tutorials as we complete the
typed key upgrade.
2023-05-22 11:35:10 -07:00
Roy Frostig
ea3389205f add jax.random.bits 2023-05-03 06:10:05 -07:00
Matthew Johnson
a7f5e07549 update prng docs to mention jax_threefry_partitionable
fixes #15484
2023-04-07 22:55:15 -07:00
jiayaobo
924894e85c add geometric random gen
add geom random

add geom random

add geom random

add geom random
2023-03-30 02:08:04 +08:00
jiayaobo
f7a14d65d2 add wald random generator
add wald to random.py
2023-03-22 11:06:59 +08:00
jiayaobo
05c47033b2 add rayleigh distribution to random.py
add rayleigh distribution to random.py

add rayleigh distribution to random.py

add rayleigh to random.py
2023-03-14 09:55:54 +08:00
jiayaobo
fdf8ac18d6 add random.chisquare and random.f
add chi2 and F random variables methods

add chi2 and F random variables methods

fix F rv shape broadcasting

fix shape broadcasting
2023-03-01 15:03:50 +08:00
Jake VanderPlas
26f2f97805 Document why 'import name as name' is used 2022-12-14 15:07:04 -08:00
Yash Katariya
a419e1917a Use jax.Array by default for doctests
PiperOrigin-RevId: 488719467
2022-11-15 11:52:22 -08:00
Peter Hawkins
cd84eb10a6 Add a number of missing function cross-references in the docs. 2022-11-07 12:00:26 -05:00
Matthew Johnson
478bd3ea4e fix comparison table in random docs
1. rbg is not identical across cpu/gpu/tpu;
2. the unsafe_rbg column copied the jax.lax.rng_uniform column from the original table, but that wasnt right, as it should be identical to the rbg column;
3. for the last row mentioning identical across shardings, we should mention that's assuming the xla flag

Also removed some rows which are only interesting in comparing to `jax.lax.rng_uniform` (which is not safe with `scan` or `remat`).

Co-authored-by: Roy Frostig <frostig@google.com>
2022-11-03 13:10:54 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Roy Frostig
0d3630b349 add key_data to jax.random for key array unwrapping
This is often useful in testing and debugging. Its more dangerous
inverse, wrapping, remains internal only.
2022-08-31 09:23:11 -07:00
Roy Frostig
6071a8f875 roll-forward #11952, take 2
Now with:
* resetting the `random.PRNGKeyArray` type during Python typechecks
* zeroing JVP rules for random primitives
* temporarily skipping vmap-of-pmap test with keys under `config.jax_array`

PiperOrigin-RevId: 469276609
2022-08-22 13:57:31 -07:00
jax authors
3a2f25ff31 roll-forward #11952
... with a small adjustment, resetting the `random.PRNGKeyArray` type
during Python typechecking.

PiperOrigin-RevId: 468840334
2022-08-19 21:02:18 -07:00
Roy Frostig
9789e83b26 roll-forward #11952
... with a small adjustment, resetting the `random.PRNGKeyArray` type
during Python typechecking.

PiperOrigin-RevId: 468835674
2022-08-19 20:12:32 -07:00
jax authors
a6c6416872 Internal change
PiperOrigin-RevId: 468712508
2022-08-19 08:56:49 -07:00
Roy Frostig
7f06df1ea1 introduce key-element-type arrays and overhaul the Python PRNG key array type
Before this change, the Python PRNG key array was a pytree type
wrapping a `uint32` array. This was a stopgap that misbehaved under
`vmap`, `scan`, and even `jax.tree_map`. For a while, we thought we
might rely on something like the typeclass mechanisms in development
(e.g. `vmappable`) to move away from a pytree.

We're now taking a different approach: introducing key element types
into our IR and other internal machinery. During staging, we map
user-facing PRNG key arrays to abstract arrays such element type.

This leans heavily on our recently-introduced extended element type
capabilities.

As a consequence, `vmap`, `scan`, etc. now work.

A sample of changes made to introduce key-element-type arrays:

* Introduce a new element type (`prng.KeyTy`), with the requisite IR
  type mapping and device result handlers, as well as lowering rules
  for dtype-polymorphic primitive operations.

* Introduce primitives for basic RNG operations: `random_seed`,
  `random_bits`, `random_split`, `random_fold_in`. These primitives
  essentially delegate to the underlying PRNG implementation (directly
  so in their impl rules, and by translating their staged-out form in
  lowering rules).

* Also introduce `random_wrap` and `random_unwrap` for "unsafe"
  conversion from/to the base `uint32` array. We need this backwards
  compatibility, and it's useful for tests.

* Introduce some `vmap`-based helpers to adapt PRNG impls (which
  define basic `random_bits`, `split`, etc. on scalars) to the above
  batch-polymorphic primitives. Most of the primitives are vectorized,
  but `random_fold_in` is a broadcasting binary op.

* Update the `gamma` primitive rules to account for key-element-type
  abstract arrays (nice simplification here).

* Give PRNG implementation short string names ("tags") for IR
  pretty-printing.

* Update `lax.stop_gradient` to handle opaque dtypes.

* Fix up loop MLIR lowering, which assumed that shaped arrays of all
  dtypes have the same physical shape.

* Add new tests (exercising staging, jaxprs, lowerings, ...)

A sample of changes made to rework Python-level PRNG key arrays:

* Mimic `isinstance(x, KeyArray)` checks on abstract key arrays and
  tracers that carry them.

* Patch (only a subset of) standard device array attributes onto PRNG
  key arrays.

* Implement various conversion handlers (sharding, constant-creation,
  `device_put`).

* Accept PRNG key arrays as input to `lax_numpy.transpose`.

* Update tests and rename some internals.

A sample of extra changes along the way:

* Disallow AD on key-typed arrays in the main API.

* Hoist `random_bits`'s named-shape-handling logic, which used to only
  take place in the threefry PRNG's `random_bits` implementation, up
  to the new `random_bits` traceable, so that we apply it consistently
  across PRNG implementations.

This change leaves some unwanted `lax` and `jax.numpy` operations
superficially available on key arrays during tracing/staging
(e.g. under `jit`), though not outside of it. We ultimately want to
disallow these and raise useful errors, and I'm leaving that for
follow-up work. For now, applying such operations under `jit` may
result in downstream errors in the middle-end instead.

Everything here is still guarded by `config.jax_enable_custom_prng`,
whose default setting hasn't changed (it is off).
2022-08-18 21:46:55 -07:00
Peter Hawkins
71b29b1cc6 Create JAX Enhancement Proposals (JEPs).
Migrate existing design documents to JEPs.
2022-08-08 16:13:58 -04:00
James Bradbury
64eb46a172
Fix RST formatting in random.py docstring 2022-07-11 02:51:35 -07:00
Tamara Norman
bc9c4b77d0 Adjust docs to account for what the actual current RNG behavior is
PiperOrigin-RevId: 459712928
2022-07-08 02:55:36 -07:00
carlosgmartin
ca83a80f95 Added random.generalized_normal and random.ball. 2022-06-03 15:11:29 -04:00
Carlos Martin
b276c31b75 Added random.orthogonal. 2022-04-29 14:20:50 -04:00
Jake VanderPlas
5782210174 CI: fix flake8 ignore declarations 2022-04-21 13:44:12 -07:00
Jeppe Klitgaard
342923335a fix: explicit reexport
Explicitly reexports PRNGKeyArray as KeyArray in accordance with PEP 484

See also: https://github.com/python/mypy/issues/11706
2022-04-19 17:54:32 +01:00
Jean-Baptiste
46a666c448 Improve the random module documentation. 2022-04-15 10:49:52 +02:00