69 Commits

Author SHA1 Message Date
Jake VanderPlas
bcb45557c4 random.choice: make compatible with strict promotion 2022-06-28 10:41:30 -07:00
carlosgmartin
ca83a80f95 Added random.generalized_normal and random.ball. 2022-06-03 15:11:29 -04:00
Jake VanderPlas
2b6387f83f [x64] make jax.random compatible with jax_numpy_dtype_promotion=strict 2022-05-27 11:12:39 -07:00
Carlos Martin
b276c31b75 Added random.orthogonal. 2022-04-29 14:20:50 -04:00
Peter Hawkins
a48752a578 [MHLO] Remove most XLA translation rules.
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.

PiperOrigin-RevId: 442547482
2022-04-18 08:28:35 -07:00
Peter Hawkins
b9bb61322c [MHLO] Prefer backend-specific HLO lowerings instead of non-backend-specific MHLO lowerings.
This allows (in subsequent changes) to switch the generic case for translating a primitive to MHLO, even if we can't yet use an MHLO lowering for a backend-specific case yet.

Add a handful of direct MLIR lowerings for primitives that lacked them.

PiperOrigin-RevId: 439912093
2022-04-06 12:53:56 -07:00
Jake VanderPlas
69969ef803 add random.loggamma and improve dirichlet & beta implementation 2022-03-21 08:33:11 -07:00
Joan Puigcerver
caf094d06b Support gamma distribution with PRNGKeys other than threefry2x32.
PiperOrigin-RevId: 433614014
2022-03-09 17:06:02 -08:00
Roy Frostig
f7731bf959 remove _const from public jax.lax module
Modify all internal call sites to use `jax._src.lax.lax._const`.
2022-03-07 12:26:25 -08:00
Jake VanderPlas
00e040e514 cleanup: remove _constant_like in favor of lax._const 2022-03-02 09:13:58 -08:00
Jake VanderPlas
2c2773a5f1 jax.random.poisson: fix corner cases 2022-02-28 12:10:47 -08:00
michaelmarien
20e5090b61 Add a warning to random.choice to notify users of the ill-defined behaviour when requesting more samples than non-zero probabilities and replace=False 2022-02-14 21:41:30 +01:00
Jake VanderPlas
e207954259 Remove unused _asarray utility 2022-01-31 12:47:49 -08:00
Roy Frostig
026b91b85d add random.default_prng_impl to retrieve the default PRNG implementation 2022-01-12 19:13:14 -08:00
Rakshit
03adf8cc89 support null shape in jax.random.poisson 2021-12-29 15:31:15 +05:30
Jake VanderPlas
f6e3f1b4ad Cleanup: remove duplicate canonicalize_axis utility 2021-11-23 16:54:02 -08:00
Jake VanderPlas
734a91350b jax.random.permutation: add independent keyword 2021-11-02 11:39:41 -07:00
Peter Hawkins
2bd010ae88 Cleanup internal representation of XLA translation rules.
Over time JAX has sprouted many variants of XLA translation rules, each with slightly different but overlapping arguments. This change consolidates them into a new xla.TranslationRule signature:
rule(ctx, avals_in, avals_out, *args, **params)
where ctx contains the parts of the other signatures that were typically not specific to a particular equation.

Since there are many JAX rules to migrate, and even a number of translation rules belonging to projects downstream of JAX, we leave backwards compatibility shims around `xla.translations`, `xla.backend_specific_translations`, and `xla.call_translations` which seem to be the only ones used outside JAX itself.

In passing, this change alters the semantics of `backend` arguments to nested `jit` blocks. We now always canonicalize the backend to a specific backend at the outermost `jit`, and do not complain if an inner `jit` has an explicit `backend` that matches the current default choice.

PiperOrigin-RevId: 403607667
2021-10-16 07:53:24 -07:00
Julius Kunze
1934fd6e65 Cleanup random.permutation 2021-10-13 14:13:00 -06:00
jax authors
10af170a85 Merge pull request #8161 from juliuskunze:multidim-permutation
PiperOrigin-RevId: 402852030
2021-10-13 09:31:19 -07:00
Julius Kunze
63898b6ca6 Allow random.choice and random.permutation on multidimensional arrays 2021-10-13 09:39:25 -06:00
Roy Frostig
ba370f8c86 hide keys attribute of PRNGKeyArray in favor of unsafe_raw_array 2021-10-12 08:47:29 -07:00
Roy Frostig
98d245ebb4 add a config setting to control the default PRNG implementation
Also add explicit seeding functions for each PRNG implementation.
2021-10-07 21:22:40 -07:00
Jake VanderPlas
1830aff2b0 random.multivariate_normal: fix broadcasting for svd & eigh methods 2021-09-28 09:58:08 -07:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Matthew Johnson
e416e87301 inline jit-decorated jax.random calls 2021-08-20 13:43:38 -07:00
Roy Frostig
4eb437a568 alias prng.threefry_2x32 in random and warn of move
Some call this, apparently.
2021-08-19 20:43:11 -07:00
Roy Frostig
60e0e9f929 implement backwards-compatible behavior and enable custom PRNGs only conditionally
Introduce a config flag for upgrading to a world of custom PRNGs. The
flag defaults off, so that we can introduce custom PRNGs into the
codebase and allow downstream libraries time to upgrade.

Backwards compatible behavior is meant in an external sense. This does
not mean that our code is internally the same any longer.
2021-08-19 20:43:11 -07:00
Roy Frostig
aa265cce95 introduce custom PRNG implementations and an array-like adapter for them
A PRNG implementation is determined by a key shape and a set of basic
functions on such a key: seed, split, random_bits, and fold_in.

A PRNG implementation can then by lifted to an array-of-keys-like
object. Namely, a new internal pytree class PRNGKeyArray wraps the
implementation and maintains an array of keys of the right shape. This
array-like object is the new "key" that gets passed around the various
functions in the public random API (e.g. `random.uniform`,
`random.normal`, ...). So the PRNGKeyArray class really serves two
purposes at once:

1. To adapt key implementations into "arrays" of such keys.
2. To carry a reference to the PRNG implementation around and delegate
   back to it from the functions in random.
2021-08-19 20:43:11 -07:00
Roy Frostig
b4ccecca88 factor PRNG routines from random module to prng 2021-08-17 19:27:31 -07:00
George Necula
0b766b27a4 [jax2tf] Improved testing and shape polymorphism support for lax.random. 2021-07-29 16:07:13 +03:00
Amol Mandhane
f945982b8c Broadcast arrays manually in categorical sampling. 2021-07-07 19:56:33 +01:00
Peter Hawkins
75c9bf01f3 Fix most test failures under NumPy 1.21. 2021-06-22 16:31:44 -04:00
Jake VanderPlas
d20ecf5db8 random: generate large sequences via vmap rather than Python loops 2021-06-17 16:10:30 -07:00
Jake VanderPlas
119c9bc0dd jax.random: improve input validation (fixes #6922) 2021-06-08 13:37:21 -07:00
Lukas Geiger
3a2e80ef51 Replace pow() with srqt() or square() where possible 2021-05-24 10:43:35 +01:00
Neil Girdhar
d724a30831 Use Array instead of jnp.array in jax.random
This satisfies MyPy for the functions:
* jax.random.categorical,
* jax.random.shuffle, and
* jax.random.permutation.
2021-05-11 01:08:36 -04:00
Jake VanderPlas
a77c96cd12 Cleanup: remove unnecessary utility 2021-05-07 10:14:34 -07:00
jax authors
0f96406130 Merge pull request #6461 from apaszke:xmap-awn
PiperOrigin-RevId: 369208554
2021-04-19 06:36:34 -07:00
Neil Girdhar
e827345202 Make Weibull, Maxwell sampling parameters nonstatic
These parameters cannot be static since numpy arrays are not hashable.
2021-04-13 17:01:29 -04:00
Peter Hawkins
26e9ebcdae Move jax.api to jax._src.api.
PiperOrigin-RevId: 368233837
2021-04-13 09:43:24 -07:00
Jake VanderPlas
5b9ea5b74d fix random.permutation for empty inputs 2021-04-12 17:00:51 -07:00
Peter Hawkins
8a450c42a7 Silence some mypy errors seen with Python 3.9 and Numpy 1.20.
None of these seem like real errors, but making mypy happy doesn't make the code much worse.
2021-04-08 11:08:45 -04:00
Peter Hawkins
6a6f13e1b0 [JAX] Move contents of jax/dtypes.py to jax/_src/dtypes.py.
PiperOrigin-RevId: 367345623
2021-04-07 19:35:51 -07:00
Jake VanderPlas
2090431ba5 random.randint: support generating the full range of dtype 2021-03-31 15:49:03 -07:00
Jake VanderPlas
640e62c7da Rollback #6293
PiperOrigin-RevId: 366119851
2021-03-31 14:43:23 -07:00
Jake VanderPlas
f0ff665eaf random.randint: clip rather than wrap out-of-bounds min/max 2021-03-31 10:01:23 -07:00
Jake VanderPlas
c11e725ecb X32 mode: raise OverflowError for large integers 2021-03-30 10:05:03 -07:00
Jake VanderPlas
9790232556 Python integer conversion: always return int64 or OverflowError 2021-03-29 09:26:19 -07:00
Neil Girdhar
af0988ee7f Annotate scatter and random 2021-03-24 20:06:17 -04:00