188 Commits

Author SHA1 Message Date
carlosgmartin
ca83a80f95 Added random.generalized_normal and random.ball. 2022-06-03 15:11:29 -04:00
Jake VanderPlas
2b6387f83f [x64] make jax.random compatible with jax_numpy_dtype_promotion=strict 2022-05-27 11:12:39 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Rohit Santhanam
8d9f17df19 Disabled one and enabled several unit tests for ROCm. 2022-05-10 19:47:26 +00:00
jax authors
227e525de2 Merge pull request #10458 from carlosgmartin:random_orthogonal_unitary
PiperOrigin-RevId: 445522278
2022-04-29 15:40:16 -07:00
Carlos Martin
b276c31b75 Added random.orthogonal. 2022-04-29 14:20:50 -04:00
YouJiacheng
4ff6b1fbca Fix PRNGKeyArray.broadcast_to with scalar shape 2022-04-16 00:27:30 +08:00
Joan Puigcerver
0c02f7935a Enable tests related to the Gamma distribution for non-default PRNG implementations only when jax_enable_custom_prng is enabled, for consistency with other tests.
PiperOrigin-RevId: 440300882
2022-04-08 01:08:55 -07:00
Jake VanderPlas
e68e87ac22 random_test: add tests of random values for distributions 2022-03-23 10:34:26 -07:00
Jake VanderPlas
69969ef803 add random.loggamma and improve dirichlet & beta implementation 2022-03-21 08:33:11 -07:00
Joan Puigcerver
caf094d06b Support gamma distribution with PRNGKeys other than threefry2x32.
PiperOrigin-RevId: 433614014
2022-03-09 17:06:02 -08:00
Reza Rahimi
a0d9d81f92 Update JAX to use new math libraries in ROCm-5.0. 2022-03-01 20:02:15 +00:00
Jake VanderPlas
2c2773a5f1 jax.random.poisson: fix corner cases 2022-02-28 12:10:47 -08:00
Roy Frostig
88c6b84daf in tests, compare jnp operations on PRNGKeyArrays to the same on jnp arrays 2022-02-17 11:43:09 -08:00
Roy Frostig
0f7904f883 implement jnp.expand_dims and jnp.stack for PRNGKeyArrays
Also:
* fix `jnp.concatenate` and `jnp.append` for PRNGKeyArrays
* add `ndim` property to PRNGKeyArrays
* minor fix to `lax.expand_dims` with duplicate dimensions
2022-02-16 20:47:27 -08:00
Jake VanderPlas
97512e9e44 JaxTestCase: set jax_numpy_rank_promotion='raise' by default 2022-02-14 09:22:05 -08:00
jax authors
5691010d2f Copybara import of the project:
--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas <jakevdp@google.com>:

JaxTestCase: set numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 427896974
2022-02-10 19:08:29 -08:00
Jake VanderPlas
6324577a63 JaxTestCase: set numpy_rank_promotion='raise' by default 2022-02-10 16:54:31 -08:00
Roy Frostig
be48964834 generate bit widths other than 32 in lax.rng_bit_generator
Three parts:

* The underlying XLA operation (RngBitGenerator) doesn't support
  generating bit widths 8 and 16, so generate 32 bits and truncate in
  the translation rule.

* Canonicalize the dtype given to `rng_bit_generator` to avoid
  requests for U64s when x64 mode is off.

* Test the effect of this on PRNG implementations backed by
  `rng_bit_generator`. Namely, their `random_bits` method should now
  support all bit widths, and their keys can be used in samplers such
  as `random.uniform` and `random.randint` to generate 16-bit floats,
  and {8,16}-bit integers respectively.
2022-01-21 08:45:37 -08:00
Roy Frostig
026b91b85d add random.default_prng_impl to retrieve the default PRNG implementation 2022-01-12 19:13:14 -08:00
Rakshit
03adf8cc89 support null shape in jax.random.poisson 2021-12-29 15:31:15 +05:30
Jake VanderPlas
4d9e9b4986 custom_prng: generalize indexing of PRNGKeyArray
Co-authored-by: Roy Frostig <frostig@google.com>
2021-12-20 10:16:32 -08:00
jax authors
d03cc4f84a Merge pull request #8886 from jakevdp:random-test-x32
PiperOrigin-RevId: 415599408
2021-12-10 13:35:42 -08:00
Jake VanderPlas
df0969961b Testing: avoid hard-coding random seeds 2021-12-10 10:32:09 -08:00
Jake VanderPlas
69fdb306de [x64] make random_test pass with jax_default_dtype_bits=32 2021-12-09 16:57:29 -08:00
Lena Martens
e14fea3b63 Overload jnp ops which are polymorphic to an array's value and support PRNGKeys. 2021-11-16 23:00:32 +00:00
Jake VanderPlas
734a91350b jax.random.permutation: add independent keyword 2021-11-02 11:39:41 -07:00
Julius Kunze
1934fd6e65 Cleanup random.permutation 2021-10-13 14:13:00 -06:00
jax authors
10af170a85 Merge pull request #8161 from juliuskunze:multidim-permutation
PiperOrigin-RevId: 402852030
2021-10-13 09:31:19 -07:00
Julius Kunze
63898b6ca6 Allow random.choice and random.permutation on multidimensional arrays 2021-10-13 09:39:25 -06:00
Roy Frostig
ba370f8c86 hide keys attribute of PRNGKeyArray in favor of unsafe_raw_array 2021-10-12 08:47:29 -07:00
Roy Frostig
98d245ebb4 add a config setting to control the default PRNG implementation
Also add explicit seeding functions for each PRNG implementation.
2021-10-07 21:22:40 -07:00
Matthew Johnson
634d252bb3 improvements to RBG PRNG
1. factor out rbg_prng_impl and unsafe_rbg_prng_impl. the former uses
   threefry2x32 for split and fold_in, while the latter uses untested
   heuristics based on calling rng_bit_generator itself as a kind of
   hash function
2. for unsafe_rbg_prng_impl's split and fold_in, generate longer
   sequences from rng_bit_generator (10x iterations) which may be useful on
   some backends
3. for unsafe_rbg_prng_impl, actually apply rng_bit_generator as our
   'hash function' in fold_in

Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Anselm Levskaya <levskaya@google.com>
2021-10-07 18:59:13 -07:00
Lena Martens
342948dcc4 Add batching rule for rng_bit_generator. 2021-10-06 17:50:07 +01:00
Matthew Johnson
980dfcfd2c add experimental RngBitGenerator ("RBG") PRNG
Not only is this an experimental API, but also because the
RngBitGenerator is not guaranteed to be stable across compiler versions
(or backends or shardings), let's assert that this JAX PRNG
implementation may not be stable across JAX versions.

Even without that kind of stability, this PRNG is still useful because
compared to effectful RNG primitives, like lax.rng_uniform, this RBG
PRNG will still work correctly with lax.scan and jax.checkpoint (while
still potentially being more performant on some platforms than JAX's
standard PRNG).
2021-10-01 19:04:14 -07:00
Jake VanderPlas
1830aff2b0 random.multivariate_normal: fix broadcasting for svd & eigh methods 2021-09-28 09:58:08 -07:00
Lena Martens
327e00a668 PRNGKeys can have dtype float0 if they are a tangent. 2021-09-24 16:37:56 +01:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Peter Hawkins
9f083d11da Use jax.* APIs rather than api.* names in tests.
Tests should use our own public APIs where they exist.
2021-09-13 16:01:32 -04:00
Roy Frostig
ab544cb26d flip lower bits to derive seeds for the custom test PRNG
Changing bits, instead of incrementing, fixes an inherited test that seeds
with a maximum-valued numpy.uint64.

PiperOrigin-RevId: 396061086
2021-09-10 21:10:13 -07:00
Roy Frostig
2d7a98beaf make PRNGKeyArray.shape a property, add tests for it and for disallowed addition 2021-09-10 18:45:18 -07:00
Roy Frostig
60e0e9f929 implement backwards-compatible behavior and enable custom PRNGs only conditionally
Introduce a config flag for upgrading to a world of custom PRNGs. The
flag defaults off, so that we can introduce custom PRNGs into the
codebase and allow downstream libraries time to upgrade.

Backwards compatible behavior is meant in an external sense. This does
not mean that our code is internally the same any longer.
2021-08-19 20:43:11 -07:00
Roy Frostig
ad4b4544df introduce custom PRNG tests
Now that `LaxRandomTest` only tests random functions and not any
specific PRNG, it can be reused to test random functions under
different PRNGs.
2021-08-19 20:43:11 -07:00
Roy Frostig
e33ef7f2e7 factor prng tests from lax random tests
Some tests check the behavior of the random bit generator---in
particular the default threefry implementation---and some check the
behavior of samplers. Separate them into different test classes.
2021-08-19 20:43:11 -07:00
Roy Frostig
aa265cce95 introduce custom PRNG implementations and an array-like adapter for them
A PRNG implementation is determined by a key shape and a set of basic
functions on such a key: seed, split, random_bits, and fold_in.

A PRNG implementation can then by lifted to an array-of-keys-like
object. Namely, a new internal pytree class PRNGKeyArray wraps the
implementation and maintains an array of keys of the right shape. This
array-like object is the new "key" that gets passed around the various
functions in the public random API (e.g. `random.uniform`,
`random.normal`, ...). So the PRNGKeyArray class really serves two
purposes at once:

1. To adapt key implementations into "arrays" of such keys.
2. To carry a reference to the PRNG implementation around and delegate
   back to it from the functions in random.
2021-08-19 20:43:11 -07:00
Jake VanderPlas
6114e6a0d3 test_util: add decorator to set config values in test cases 2021-08-05 14:06:37 -07:00
Peter Hawkins
b232d09440 Enable flake8 checks for spaces around operators. 2021-07-30 08:45:38 -04:00
Jake VanderPlas
233d9f79c6 Run random_test with rank_promotion='raise' 2021-07-08 11:20:05 -07:00
Amol Mandhane
f945982b8c Broadcast arrays manually in categorical sampling. 2021-07-07 19:56:33 +01:00
Jake VanderPlas
71bf410055 Fix random test for scipy>=1.17.0 2021-06-21 13:44:54 -07:00