213 Commits

Author SHA1 Message Date
Jake VanderPlas
ba8877789d Roll back https://github.com/jax-ml/jax/pull/28022 due to test breakages.
Reverts b336daf747940301de5956dce4ebe790298e6b5b

PiperOrigin-RevId: 747988862
2025-04-15 13:00:04 -07:00
Jake VanderPlas
c56cf4f68d jax.random.bernoulli: use higher-resolution sampler 2025-04-15 08:18:47 -07:00
Peter Hawkins
e02faabfb2 Replace references to jax.readthedocs.io with docs.jax.dev.
PiperOrigin-RevId: 745156931
2025-04-08 08:33:49 -07:00
Roy Frostig
97cecdf862 add an out_sharding option to jax.random.truncated_normal
Drop into `Auto` mode in the implementation.
2025-04-03 22:34:08 -07:00
Roy Frostig
f8bbe98a86 require out_shardings as a keyword-only argument on public functions
PiperOrigin-RevId: 743753215
2025-04-03 17:26:05 -07:00
Roy Frostig
bbdea54ccb add an out_sharding option to jax.random.permutation
Drop into `Auto` mode in the implementation.
2025-04-03 16:21:45 -07:00
Roy Frostig
ab816ed8c4 add an out_sharding option to jax.random.randint
Drop into `Auto` mode in the implementation.
2025-04-02 21:05:19 -07:00
Roy Frostig
2f617631fb use common maybe_auto_axes helper in random.uniform 2025-04-02 17:47:25 -07:00
Roy Frostig
2540fcde11 add an out_sharding option to jax.random.bits
Drop into `Auto` mode in the implementation.
2025-04-02 17:19:57 -07:00
Yash Katariya
2e16367991 Remove the extra stack frame that was introduce in uniform due to dropping the entire function in auto axes.
PiperOrigin-RevId: 743148311
2025-04-02 08:30:27 -07:00
Yash Katariya
cc51412019 [sharding_in_types] Add out_sharding to jax.random.normal.
Drop into `Auto` mode inside for implementation.

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 740538785
2025-03-25 17:12:39 -07:00
Yash Katariya
087a38988c [sharding_in_types] Add out_sharding to jax.random.uniform.
Drop into `Auto` mode inside for implementation.

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 740529498
2025-03-25 16:42:19 -07:00
carlosgmartin
b5c467e6cf Fix doc for random.categorical replace argument. 2025-03-19 23:56:24 -04:00
carlosgmartin
3f59fa6888 Add replace option to random.categorical to enable sampling without replacement. 2025-03-17 13:41:46 -04:00
Martin Muller
4a82fe94de Use lax.top_k instead of jnp.argsort in Gumbel top-k trick for weighted sampling without replacement in jax.random.choice 2025-03-14 19:02:24 +01:00
jax authors
bf829ff612 Merge pull request #26524 from carlosgmartin:random_multinomial
PiperOrigin-RevId: 736569564
2025-03-13 11:05:17 -07:00
carlosgmartin
6b69a136aa Add jax.random.multinomial. 2025-03-12 18:15:14 -04:00
carlosgmartin
8b6ca56417 Fix the ValueError message for random.binomial (forgot to use string formatting). 2025-03-10 16:38:03 -04:00
Parker Schuh
b8b690e594 Add use_high_dynamic_range_gumbel flag which allows sampling gumbel such
that it more closely matches the CDF for low probably events (less than
2**-nmant).

Because -log(-log(x)) is more sensitive close to 1 than 0, we must use
-log(-logp1(-x)) instead to make better use of the extra range around 0.

PiperOrigin-RevId: 732757388
2025-03-02 19:42:40 -08:00
carlosgmartin
ba428d8cda Extend random.orthogonal to semi-orthogonal matrices. Simplify initializers.orthogonal by using it. 2025-02-26 16:39:45 -05:00
Jake VanderPlas
e4dac395a5 Roll back multinomial change from https://github.com/jax-ml/jax/pull/25688
This has test breakages on TPU: https://github.com/jax-ml/jax/actions/runs/13159081976/job/36723019653

Reverts 95535df13b422284043623ca3a6d2a5962116fb1

PiperOrigin-RevId: 723536107
2025-02-05 09:13:56 -08:00
carlosgmartin
c478f44e9d Simplify implementation of random.orthogonal. 2025-02-03 15:02:17 -05:00
carlosgmartin
32411a430f Add jax.random.multinomial. 2025-01-31 18:45:55 -05:00
Jake VanderPlas
216bd9a6cc Fix dtype issue in stirling approximation 2025-01-31 14:13:02 -08:00
Roy Frostig
b83049638b fix gamma_p in vmap-based impl rule mode 2024-12-13 15:32:09 -08:00
Jake VanderPlas
40367a9eaf Cleanup: remove uses of no-op raise_to_shaped 2024-12-12 09:49:06 -08:00
Jake VanderPlas
fee272e550 Remove internal KeyArray alias
This was useful during the transition to typed PRNG keys, but
is no longer necessary. It also makes generated HTML docs
confusing: it's better to just use Array as we expect users to.
2024-11-20 10:30:12 -08:00
Roy Frostig
4bb81075bc represent random.key_impl of builtin RNGs by canonical string name
We do not have great reason to return specs here, and sticking to
strings instead can help with simple serialization.
2024-11-19 20:58:10 -08:00
George Necula
e5f4be5564 [shape_poly] Expands support for random.choice
`random.choice` uses `np.insert(arr.shape, new_shape)` which attempts
to coerce all the values in `new_shape` to constants when `arr.shape`
is constant. Replace use of `np.insert` with tuple slicing and
concatenation.

The case when the sampled axis has non-constant size and
`replace=False` is not supported, because `permutation` on
arrays with non-constant size is not supported.

Adds tests for many combinations of arguments for `random.choice`.
Improves a few error messages.
2024-10-24 17:20:09 +03:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
jax authors
720dfd7e43 Merge pull request #23257 from carlosgmartin:random-orthogonal-citation
PiperOrigin-RevId: 671156668
2024-09-04 17:24:51 -07:00
carlosgmartin
5b8c8dd12c Add citation for random.orthogonal. 2024-09-04 14:43:54 -04:00
Jake VanderPlas
d6394c0795 random.key_impl: improve repr of output 2024-09-04 10:10:31 -07:00
Sergei Lebedev
59f825a23e Fixed the return type of `jax.random.key_impl`
Closes #23363.
2024-09-02 21:49:53 +01:00
Roy Frostig
dd535d88a7 emphasize typed over legacy RNG keys in random module docs
Update both docstrings and move the `PRNGKey` function listing lower
in the API reference.
2024-08-11 12:41:50 -07:00
Jake VanderPlas
a17c8d945b Finalize deprecation of jax.random.shuffle
This has been raising a DeprecationWarning for longer than anyone can remember.

PiperOrigin-RevId: 656765001
2024-07-27 11:21:49 -07:00
Eugene Zhulenev
15d4389247 Use vmap for random_gamma implementation on CPU backend
XLA:CPU is preparing to switch from compiling whole XLA program into a single LLVM function to a mode where each fusion/kernel will have its own entry point, and a thin runtime that will dispatch compute functions concurrently. This execution mode does not work very well with while loops with tiny computations and large number of iterations. Similar to GPU backend use vmap to avoid excessive runtime overheads.

Context: https://github.com/openxla/community/pull/96
PiperOrigin-RevId: 656199716
2024-07-25 19:41:59 -07:00
Matthew Johnson
3f9eb404e4 remove named_shapes (since xmap is now gone) 2024-07-25 00:54:50 +00:00
Dan Foreman-Mackey
6d35b109fd Rename "Example" to "Examples" in docstrings.
This PR updates all docstrings that previously had a section heading
called "Example" and replaces that with "Examples" to be consistent.
2024-06-21 11:43:16 -04:00
jax authors
f51af87fc5 fp8 matmul in pallas
PiperOrigin-RevId: 641254832
2024-06-07 08:17:06 -07:00
carlosgmartin
de8a0d3be6 Add default shape to random.rademacher. 2024-06-04 00:16:58 -04:00
Jake VanderPlas
568987af23 Finalize deprecation of batched keys to PRNG functions
PiperOrigin-RevId: 636196573
2024-05-22 09:40:32 -07:00
Sergei Lebedev
f5617d7323 Removed noop # type: ignore comments
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
jax authors
d05d29d889 Merge pull request #21050 from rajasekharporeddy:test_branch3
PiperOrigin-RevId: 630339307
2024-05-03 03:24:22 -07:00
rajasekharporeddy
ccabdb29ea Fix typos in docs and an error message 2024-05-03 12:08:22 +05:30
rajasekharporeddy
a0b93153ca Fix Typos and math rendering in jax.random docs 2024-05-02 17:43:37 +05:30
jax authors
0302e4c34d Merge pull request #17741 from froystig:new-style-key-docs
PiperOrigin-RevId: 614080080
2024-03-08 16:41:22 -08:00
Jake VanderPlas
d1e49f9c89 [key reuse] fix random_clone impl rule 2024-03-08 15:16:39 -08:00
jax authors
c4cf265f86 Merge pull request #20094 from froystig:vmap-rbg
PiperOrigin-RevId: 614034982
2024-03-08 13:52:01 -08:00
Roy Frostig
29edfd8925 define a loop-free untrue batching rule for rng_bit_generator 2024-03-08 13:13:03 -08:00