For example:
```
k1, k2 = jax.random.split(key, 2) # where key is fully replicated on 8 devices
```
Then `k1` and `k2` should also maintain the sharding of `key` since `key` is fully replicated.
PiperOrigin-RevId: 480434272
It's a bit of a weird use-case, since MANUAL mode is meant for xmaps that
are nested inside pjits, but it doesn't hurt us to support it.
PiperOrigin-RevId: 480342531
Also look at the shardings inside the jaxpr for `sharding_constraint_p` and `pjit_p` primitives since with `jax.Array`, each `with_sharding_constraint`/`pjit` inside a computation can contain a different sharding (so we need to check if the device_assignment is the same).
Also the output is `committed` if there are jaxpr shardings inside the computation via `with_sharding_constraint`/`pjit` or if any of the inputs are committed or `output_sharding` is specified.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 480256796
--
d39bdefb33a19e407c352df27fb04127f4fe8a1d by Peter Hawkins <phawkins@google.com>:
Migrate more tests from jtu.cases_from_list to jtu.sample_product.
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/12717 from hawkinsp:sampletest d39bdefb33a19e407c352df27fb04127f4fe8a1d
PiperOrigin-RevId: 480136538
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.
It turns out some users are relying on the API contract of the custom calls within serialized HLO remaining stable. For the moment, we reapply only the Python changes. The C++ code is already tolerant of both aliased and unaliased outputs, and this gets us all the benefit of saving a copy. We can break backwards compatibility on the serialized HLO after users upgrade their saved HLO to the aliased version.
PiperOrigin-RevId: 480134780
Why?
1. Because it's easy to cache a property than a method with only the `self` argument. (See below for article)
2. There's no harm in making them a property because both of them return a bool without any side-effects and are cached (so its fast). Why cache `is_fully_addressable`? Because its very expensive to calculate when you have 1000s of devices.
PiperOrigin-RevId: 479850850
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.
This seems to be causing some test failures downstream, so reverting this for the moment until I can debug them.
PiperOrigin-RevId: 479670565