The issue is that the batching rule assumes that each scatter variant
always has the same update_jaxpr. This is not true of scatter_apply, which
lowers to scatter with a custom update_jaxpr. To address this, we change
the batching rule such that it re-uses the input jaxpr rather than always
re-generating it.
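
As a rough illustration of the kind of program affected (a sketch, not taken from the test suite): `x.at[idx].apply(fn)` is one user-facing path to `scatter_apply`, and mapping it with `jax.vmap` exercises the scatter batching rule. The helper name `apply_sin_at_zero` and the input values are made up for this example.

```python
import jax
import jax.numpy as jnp

def apply_sin_at_zero(x):
    # x.at[...].apply(fn) goes through scatter_apply, which carries its own
    # update_jaxpr (here built from jnp.sin).
    return x.at[0].apply(jnp.sin)

# Two rows to map over (illustrative values).
xs = jnp.arange(6.0).reshape(2, 3) + 1.0

# vmap dispatches to the scatter batching rule.
batched = jax.vmap(apply_sin_at_zero)(xs)

# Reference result computed without vmap.
expected = xs.at[:, 0].set(jnp.sin(xs[:, 0]))
```

If the batching rule regenerated a default update_jaxpr instead of reusing the one it was given, `batched` would not be expected to match `expected`.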
This touches _gather_batching_rule because slicing is implemented as a
gather, but we only test the case exercised by the slice that occurs
in our test transformer model, namely the unstack operation
    q, k, v = qkv
(which turns into three slices on a non-batched and non-ragged axis).
Co-authored-by: Matthew Johnson <mattjj@google.com>
This is an incremental change to our random tests that primarily:
* Increases test coverage of both key constructors (`random.key` and
`random.PRNGKey`), often by parameterizing tests over both.
* Increases test coverage of both key representations (typed key
arrays and `uint32` arrays).
* Removes a handful of guards on `config.jax_enable_custom_prng`,
either replacing them with `isinstance` checks for typed keys or
removing them altogether if possible.
* Makes a handful of other individual test improvements and fixes, and
leaves comments for more.
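
A minimal sketch of what covering both constructors and both representations can look like, outside any test framework. The helper `is_typed_key` is hypothetical, and it uses a dtype check via `jax.dtypes.prng_key` as a stand-in for the `isinstance` checks mentioned above; the actual tests may be written differently.

```python
import jax
import jax.numpy as jnp

def is_typed_key(key):
    # Typed key arrays have a PRNG key dtype; old-style keys are uint32 arrays.
    return bool(jnp.issubdtype(key.dtype, jax.dtypes.prng_key))

constructors = {"key": jax.random.key, "PRNGKey": jax.random.PRNGKey}

results = {}
for name, make_key in constructors.items():
    key = make_key(0)
    # Record the representation and a sample drawn from this key.
    results[name] = (is_typed_key(key), jax.random.uniform(key, (3,)))
```

Under the default PRNG implementation, both constructors should yield the same samples for the same seed, which is what makes parameterizing one test body over both practical.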
This change primarily adds an optional argument to both old- and
new-style random key constructors. The option determines the PRNG
implementation for the key by name, overriding any default
implementation determined by configuration flags.
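
For illustration, a short sketch of the option on both constructors, assuming the keyword is spelled `impl` and that `"rbg"` is among the accepted implementation names; the variable names are made up for this example.

```python
import jax

# New-style typed key with an explicitly named implementation.
typed_key = jax.random.key(0, impl="rbg")

# Old-style uint32 key with the same implementation.
raw_key = jax.random.PRNGKey(0, impl="rbg")

# Without the option, the implementation comes from configuration flags.
default_key = jax.random.key(0)

# The underlying uint32 data differs in shape between implementations.
typed_data_shape = jax.random.key_data(typed_key).shape
default_data_shape = jax.random.key_data(default_key).shape

# A typed key carries its implementation, so it can be used directly.
sample = jax.random.uniform(typed_key, (2,))
```

Note that an old-style `uint32` key does not record which implementation produced it, so only the typed key is sampled from here.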
Along the way, looking ahead:
* We can deprecate the (anyway underused) individual explicit key
constructors like `jax.random.threefry2x32_key` in favor of this
option.
* Some day, instead of only accepting RNG implementations by name
(string), we can also accept the output of some custom PRNG
implementation API that we expose, maybe via `jax.extend.random`
(corresponding roughly to the current `_src.prng.PRNGImpl`).
If both the second and third operand of a `lax.cond` call are callable, then
resolve it as a new-style (default) conditional, where both branches act on the
same operands.
This changes the behavior of five-argument `lax.cond` calls. It is a breaking
change for callers using the old-style `cond` calling convention (`pred`,
`true_arg`, `true_fn`, `false_arg`, `false_fn`) with a callable `true_arg`.
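
A small sketch of the new-style convention, including a five-argument call in which the second and third arguments are callables and the remaining two are shared operands. The function names and values are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

def double(x):
    return x * 2

def negate(x):
    return -x

x = jnp.float32(3.0)

# New-style: pred, true_fn, false_fn, then operands shared by both branches.
taken_true = lax.cond(x > 0, double, negate, x)
taken_false = lax.cond(x < 0, double, negate, x)

# Five arguments, second and third callable: both branches receive (a, b).
a, b = jnp.float32(5.0), jnp.float32(2.0)
five_arg = lax.cond(a > b, lambda a, b: a + b, lambda a, b: a - b, a, b)
```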
PiperOrigin-RevId: 543912445
sharding=None means that JAX is free to choose whatever sharding it wants. As it stands, JAX will choose to mark the input as replicated, but it reserves the right to change that as it sees fit.
PiperOrigin-RevId: 543630595
This is similar to how send/receive callbacks are implemented.
Update make_c_api_client to take key-value get/put callbacks generated from the distributed client, and options for node_id and num_nodes.
PiperOrigin-RevId: 543441403
We're now moving to a world where custom PRNG should exist side-by-side with the old PRNG
implementation. This change improves test coverage for that, by enabling relevant tests
even when the flag is set to False.
* Exposes LoadedExecutable.cost_analysis via pybind
* Updates XlaExecutable.cost_analysis to try
LoadedExecutable.cost_analysis, then fall back to the client method.
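
From the Python side, this is reached through the compiled object returned by ahead-of-time lowering. A minimal sketch follows; the function and shapes are made up, and the exact structure of the returned analysis is left unspecified here because it can vary by backend and version.

```python
import jax
import jax.numpy as jnp

def matmul(a, b):
    return a @ b

a = jnp.ones((8, 16))
b = jnp.ones((16, 4))

# Lower and compile ahead of time, then query the compiled executable.
compiled = jax.jit(matmul).lower(a, b).compile()

# Cost estimates for the compiled computation; contents are backend-dependent.
analysis = compiled.cost_analysis()
```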
PiperOrigin-RevId: 542671990