Jake VanderPlas
ba8877789d
Roll back https://github.com/jax-ml/jax/pull/28022 due to test breakages.
...
Reverts b336daf747940301de5956dce4ebe790298e6b5b
PiperOrigin-RevId: 747988862
2025-04-15 13:00:04 -07:00
Jake VanderPlas
c56cf4f68d
jax.random.bernoulli: use higher-resolution sampler
2025-04-15 08:18:47 -07:00
Peter Hawkins
e02faabfb2
Replace references to jax.readthedocs.io with docs.jax.dev.
...
PiperOrigin-RevId: 745156931
2025-04-08 08:33:49 -07:00
Roy Frostig
97cecdf862
add an out_sharding
option to jax.random.truncated_normal
...
Drop into `Auto` mode in the implementation.
2025-04-03 22:34:08 -07:00
Roy Frostig
f8bbe98a86
require out_shardings
as a keyword-only argument on public functions
...
PiperOrigin-RevId: 743753215
2025-04-03 17:26:05 -07:00
Roy Frostig
bbdea54ccb
add an out_sharding
option to jax.random.permutation
...
Drop into `Auto` mode in the implementation.
2025-04-03 16:21:45 -07:00
Roy Frostig
ab816ed8c4
add an out_sharding
option to jax.random.randint
...
Drop into `Auto` mode in the implementation.
2025-04-02 21:05:19 -07:00
Roy Frostig
2f617631fb
use common maybe_auto_axes
helper in random.uniform
2025-04-02 17:47:25 -07:00
Roy Frostig
2540fcde11
add an out_sharding
option to jax.random.bits
...
Drop into `Auto` mode in the implementation.
2025-04-02 17:19:57 -07:00
Yash Katariya
2e16367991
Remove the extra stack frame that was introduce in uniform due to dropping the entire function in auto axes.
...
PiperOrigin-RevId: 743148311
2025-04-02 08:30:27 -07:00
Yash Katariya
cc51412019
[sharding_in_types] Add out_sharding to jax.random.normal
.
...
Drop into `Auto` mode inside for implementation.
Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 740538785
2025-03-25 17:12:39 -07:00
Yash Katariya
087a38988c
[sharding_in_types] Add out_sharding
to jax.random.uniform
.
...
Drop into `Auto` mode inside for implementation.
Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 740529498
2025-03-25 16:42:19 -07:00
carlosgmartin
b5c467e6cf
Fix doc for random.categorical replace argument.
2025-03-19 23:56:24 -04:00
carlosgmartin
3f59fa6888
Add replace option to random.categorical to enable sampling without replacement.
2025-03-17 13:41:46 -04:00
Martin Muller
4a82fe94de
Use lax.top_k
instead of jnp.argsort
in Gumbel top-k trick for weighted sampling without replacement in jax.random.choice
2025-03-14 19:02:24 +01:00
jax authors
bf829ff612
Merge pull request #26524 from carlosgmartin:random_multinomial
...
PiperOrigin-RevId: 736569564
2025-03-13 11:05:17 -07:00
carlosgmartin
6b69a136aa
Add jax.random.multinomial.
2025-03-12 18:15:14 -04:00
carlosgmartin
8b6ca56417
Fix the ValueError message for random.binomial (forgot to use string formatting).
2025-03-10 16:38:03 -04:00
Parker Schuh
b8b690e594
Add use_high_dynamic_range_gumbel flag which allows sampling gumbel such
...
that it more closely matches the CDF for low probably events (less than
2**-nmant).
Because -log(-log(x)) is more sensitive close to 1 than 0, we must use
-log(-logp1(-x)) instead to make better use of the extra range around 0.
PiperOrigin-RevId: 732757388
2025-03-02 19:42:40 -08:00
carlosgmartin
ba428d8cda
Extend random.orthogonal to semi-orthogonal matrices. Simplify initializers.orthogonal by using it.
2025-02-26 16:39:45 -05:00
Jake VanderPlas
e4dac395a5
Roll back multinomial change from https://github.com/jax-ml/jax/pull/25688
...
This has test breakages on TPU: https://github.com/jax-ml/jax/actions/runs/13159081976/job/36723019653
Reverts 95535df13b422284043623ca3a6d2a5962116fb1
PiperOrigin-RevId: 723536107
2025-02-05 09:13:56 -08:00
carlosgmartin
c478f44e9d
Simplify implementation of random.orthogonal.
2025-02-03 15:02:17 -05:00
carlosgmartin
32411a430f
Add jax.random.multinomial.
2025-01-31 18:45:55 -05:00
Jake VanderPlas
216bd9a6cc
Fix dtype issue in stirling approximation
2025-01-31 14:13:02 -08:00
Roy Frostig
b83049638b
fix gamma_p
in vmap-based impl rule mode
2024-12-13 15:32:09 -08:00
Jake VanderPlas
40367a9eaf
Cleanup: remove uses of no-op raise_to_shaped
2024-12-12 09:49:06 -08:00
Jake VanderPlas
fee272e550
Remove internal KeyArray alias
...
This was useful during the transition to typed PRNG keys, but
is no longer necessary. It also makes generated HTML docs
confusing: it's better to just use Array as we expect users to.
2024-11-20 10:30:12 -08:00
Roy Frostig
4bb81075bc
represent random.key_impl
of builtin RNGs by canonical string name
...
We do not have great reason to return specs here, and sticking to
strings instead can help with simple serialization.
2024-11-19 20:58:10 -08:00
George Necula
e5f4be5564
[shape_poly] Expands support for random.choice
...
`random.choice` uses `np.insert(arr.shape, new_shape)` which attempts
to coerce all the values in `new_shape` to constants when `arr.shape`
is constant. Replace use of `np.insert` with tuple slicing and
concatenation.
The case when the sampled axis has non-constant size and
`replace=False` is not supported, because `permutation` on
arrays with non-constant size is not supported.
Adds tests for many combinations of arguments for `random.choice`.
Improves a few error messages.
2024-10-24 17:20:09 +03:00
Michael Hudgins
d4d1518c3d
Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
...
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
jax authors
720dfd7e43
Merge pull request #23257 from carlosgmartin:random-orthogonal-citation
...
PiperOrigin-RevId: 671156668
2024-09-04 17:24:51 -07:00
carlosgmartin
5b8c8dd12c
Add citation for random.orthogonal.
2024-09-04 14:43:54 -04:00
Jake VanderPlas
d6394c0795
random.key_impl: improve repr of output
2024-09-04 10:10:31 -07:00
Sergei Lebedev
59f825a23e
Fixed the return type of `jax.random.key_impl
`
...
Closes #23363 .
2024-09-02 21:49:53 +01:00
Roy Frostig
dd535d88a7
emphasize typed over legacy RNG keys in random
module docs
...
Update both docstrings and move the `PRNGKey` function listing lower
in the API reference.
2024-08-11 12:41:50 -07:00
Jake VanderPlas
a17c8d945b
Finalize deprecation of jax.random.shuffle
...
This has been raising a DeprecationWarning for longer than anyone can remember.
PiperOrigin-RevId: 656765001
2024-07-27 11:21:49 -07:00
Eugene Zhulenev
15d4389247
Use vmap for random_gamma implementation on CPU backend
...
XLA:CPU is preparing to switch from compiling whole XLA program into a single LLVM function to a mode where each fusion/kernel will have its own entry point, and a thin runtime that will dispatch compute functions concurrently. This execution mode does not work very well with while loops with tiny computations and large number of iterations. Similar to GPU backend use vmap to avoid excessive runtime overheads.
Context: https://github.com/openxla/community/pull/96
PiperOrigin-RevId: 656199716
2024-07-25 19:41:59 -07:00
Matthew Johnson
3f9eb404e4
remove named_shapes (since xmap is now gone)
2024-07-25 00:54:50 +00:00
Dan Foreman-Mackey
6d35b109fd
Rename "Example" to "Examples" in docstrings.
...
This PR updates all docstrings that previously had a section heading
called "Example" and replaces that with "Examples" to be consistent.
2024-06-21 11:43:16 -04:00
jax authors
f51af87fc5
fp8 matmul in pallas
...
PiperOrigin-RevId: 641254832
2024-06-07 08:17:06 -07:00
carlosgmartin
de8a0d3be6
Add default shape to random.rademacher.
2024-06-04 00:16:58 -04:00
Jake VanderPlas
568987af23
Finalize deprecation of batched keys to PRNG functions
...
PiperOrigin-RevId: 636196573
2024-05-22 09:40:32 -07:00
Sergei Lebedev
f5617d7323
Removed noop # type: ignore comments
...
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
jax authors
d05d29d889
Merge pull request #21050 from rajasekharporeddy:test_branch3
...
PiperOrigin-RevId: 630339307
2024-05-03 03:24:22 -07:00
rajasekharporeddy
ccabdb29ea
Fix typos in docs and an error message
2024-05-03 12:08:22 +05:30
rajasekharporeddy
a0b93153ca
Fix Typos and math rendering in jax.random docs
2024-05-02 17:43:37 +05:30
jax authors
0302e4c34d
Merge pull request #17741 from froystig:new-style-key-docs
...
PiperOrigin-RevId: 614080080
2024-03-08 16:41:22 -08:00
Jake VanderPlas
d1e49f9c89
[key reuse] fix random_clone impl rule
2024-03-08 15:16:39 -08:00
jax authors
c4cf265f86
Merge pull request #20094 from froystig:vmap-rbg
...
PiperOrigin-RevId: 614034982
2024-03-08 13:52:01 -08:00
Roy Frostig
29edfd8925
define a loop-free untrue batching rule for rng_bit_generator
2024-03-08 13:13:03 -08:00