Before this change, primitives have a special case dispatch path that attempts
to avoid building a jaxpr in the cache miss case. However, there's no good
reason for this: it makes the code more complicated, and we're not particularly
optimizing for fast cache misses anyway (we care mostly about cache hits).
Make the primitive lowering path trace a small function using the xla_callable
lowering path instead.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...
PiperOrigin-RevId: 400858477
* Don't wrap static arguments in hashable wrappers in pmap.
* Delete wrap_hashably().
* In argnums_partial, either enforce hashability or wrap values with an explicitly unhashable wrapper. The intent here is that either we should check for hashability early or we should make sure it's clear that it's not something we intended..
* Delete argnames_partial, which appears unused.
A PRNG implementation is determined by a key shape and a set of basic
functions on such a key: seed, split, random_bits, and fold_in.
A PRNG implementation can then by lifted to an array-of-keys-like
object. Namely, a new internal pytree class PRNGKeyArray wraps the
implementation and maintains an array of keys of the right shape. This
array-like object is the new "key" that gets passed around the various
functions in the public random API (e.g. `random.uniform`,
`random.normal`, ...). So the PRNGKeyArray class really serves two
purposes at once:
1. To adapt key implementations into "arrays" of such keys.
2. To carry a reference to the PRNG implementation around and delegate
back to it from the functions in random.
It's the exact same code as for JIT. We just modify the Python function to accept ShardedDeviceArray in addition to DeviceArray objects. The test is updated accordingly.
PiperOrigin-RevId: 391272270
This **will** be a **breaking** change, as pxla.ShardedDeviceArray constructor won't be valid anymore:
- for the next Jax release
- on the condition _USE_EXPERIMENTAL_CPP_SDA is switch to `_xla_extension_version > xx` and with the associated jaxlib release.
I am already adding the impact for the users in the CHANGELOG, we can still move it to the next version depending on when it's shipped.
Similarly to JAX.jit, for which we have a C++ `DeviceArray` and a Python `_DeviceArray`, we will introduce 2 objects for ShardedDeviceArray, with the Python object only for JAX extensions not compatible with the C++ object (e.g. Cloud TPU).
- Add `make_sharded_device_array` to be used within JAX and for hackers that need to construct SDA objects.
- Make sure the C++ object is valid by
(a) extending `DeviceArrayBase` (done in Python), as it brings a bunch of methods and enable `isinstance(x, DeviceArray)`
(b) Adding the same methods as the Python SDA.
NOTE: mypy has troubled with the " -> pxla.ShardedDeviceArray` function return type annotation, I had to remove 2.
PiperOrigin-RevId: 389876734
This is backward compatible, as the new objects has the same attributes with the same type (in particular, it can be constructed from iterable objects, and `sharding` and `mesh_mapping` are still tuples.
PiperOrigin-RevId: 388565058
This is especially convenient when using JAX as an HLO generator, because the
HLO AllGather defaults to the tiling behavior.
PiperOrigin-RevId: 384897270
There are a few test cases that generate millions of configurations,
only to have a handful of them selected by `cases_form_list`. I've
found all tests that spend over 100ms in case generation and
converted them to a new "test sampler" approach. The result: test
generation time drops from 15s to around 2s. Doesn't sound like much,
but I expect that we all run tests many times daily, so it seems like a
useful thing to have.
The rough idea is that the sampling generators get parameterized by a
sampler function that should be applied to the range of every `for` loop.
This allows us to sample runs of the generator through different
configurations by restricting each loop to a smaller subset. Right now
we always narrow it down to a single randomly selected instance. But,
we still retain the possibility of adding exhaustive testing in the
future, which can be achieved by passing in an identity sampling
function that wouldn't modify any loop ranges.
This has the benefit of limiting the insane axis arithmetic (with some
axes getting removed, and others introduced with their positions offset
by the removals) to the all_to_all user-facing function, but all the
collective rules should now be simpler to write. This should be a no-op
from the point of view of the users, but should make enabling all_to_all
splitting easier.
Change the structure of `execute_replicated` so that `in_handlers` and
`out_handlers` return and take `args[arg][shard]`
instead of `args[shard][arg]`.
This is an expansion of the first, rolled-back attempt (https://github.com/google/jax/pull/5260), this time with auto-diff and batching rules that some users are relying on.
My benchmarks suggest a speed-up of ~2-2.5x for larger inputs.